{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Chapter 16 – Reinforcement Learning**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "This notebook contains all the sample code and solutions to the exercises in chapter 16.\n",
    "\n",
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/ageron/handson-ml/blob/master/16_reinforcement_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: this is the code for the 1st edition of the book. Please visit https://github.com/ageron/handson-ml2 for the 2nd edition code, with up-to-date notebooks using the latest library versions. In particular, the 1st edition is based on TensorFlow 1, while the 2nd edition uses TensorFlow 2, which is much simpler to use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "First, let's make sure this notebook works well in both Python 2 and Python 3, import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import sklearn\n",
    "import sys\n",
    "\n",
    "try:\n",
    "    # %tensorflow_version only exists in Colab.\n",
    "    %tensorflow_version 1.x\n",
    "    !apt update && apt install -y libpq-dev libsdl2-dev swig xorg-dev xvfb\n",
    "    !pip install -q -U pyvirtualdisplay gym[atari,box2d]\n",
    "    IS_COLAB = True\n",
    "except Exception:\n",
    "    IS_COLAB = False\n",
    "\n",
     "import tensorflow as tf\n",
     "\n",
     "# to make this notebook's output stable across runs\n",
    "def reset_graph(seed=42):\n",
    "    tf.reset_default_graph()\n",
    "    tf.set_random_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "# To plot pretty figures\n",
    "%matplotlib inline\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "mpl.rc('axes', labelsize=14)\n",
    "mpl.rc('xtick', labelsize=12)\n",
    "mpl.rc('ytick', labelsize=12)\n",
    "\n",
    "# To get smooth animations\n",
    "import matplotlib.animation as animation\n",
    "mpl.rc('animation', html='jshtml')\n",
    "\n",
    "# Where to save the figures\n",
    "PROJECT_ROOT_DIR = \".\"\n",
    "CHAPTER_ID = \"rl\"\n",
    "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
    "os.makedirs(IMAGES_PATH, exist_ok=True)\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
    "    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
    "    print(\"Saving figure\", fig_id)\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format=fig_extension, dpi=resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: there may be minor differences between the output of this notebook and the examples shown in the book. You can safely ignore these differences. They are mainly due to the fact that most of the environments provided by OpenAI gym have some randomness."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to OpenAI gym"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we will be using [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and comparing Reinforcement Learning algorithms. It provides many environments for your learning *agents* to interact with. Let's start by importing `gym`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we will load the MsPacman environment, version 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('MsPacman-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Let's initialize the environment by calling its `reset()` method. This returns an observation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.seed(42)\n",
    "obs = env.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Observations vary depending on the environment. In this case the observation is an RGB image, represented as a 3D NumPy array of shape [height, width, channels] (with 3 channels: Red, Green and Blue). Other environments may return different kinds of objects, as we will see later."
   ]
  },
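  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the observation is just a NumPy array, you can preprocess it with regular NumPy operations. Here is a minimal sketch (using a dummy all-zeros frame rather than a real observation) that converts an Atari-style frame to grayscale and downsamples it by a factor of 2:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "frame = np.zeros((210, 160, 3), dtype=np.uint8)  # dummy [height, width, channels] frame\n",
    "gray = frame.mean(axis=2)   # average the 3 color channels\n",
    "small = gray[::2, ::2]      # keep every 2nd row and every 2nd column\n",
    "assert small.shape == (105, 80)\n",
    "```"
   ]
  },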
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(210, 160, 3)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "**Warning**: some environments require access to your display, which opens up a separate window, even if you specify `mode=\"rgb_array\"`. In general you can safely ignore that window. However, if Jupyter is running on a headless server (i.e., without a screen) it will raise an exception. One way to avoid this is to install a fake X server like [Xvfb](http://en.wikipedia.org/wiki/Xvfb). On Debian or Ubuntu:\n",
    "\n",
    "```bash\n",
    "$ apt update\n",
    "$ apt install -y xvfb\n",
    "```\n",
    "\n",
    "You can then start Jupyter using the `xvfb-run` command:\n",
    "\n",
    "```bash\n",
    "$ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n",
    "```\n",
    "\n",
    "Alternatively, you can install the [pyvirtualdisplay](https://github.com/ponty/pyvirtualdisplay) Python library which wraps Xvfb:\n",
    "\n",
    "```bash\n",
    "python3 -m pip install -U pyvirtualdisplay\n",
    "```\n",
    "\n",
    "And run the following code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import pyvirtualdisplay\n",
    "    display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()\n",
    "except ImportError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(210, 160, 3)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = env.render(mode=\"rgb_array\")\n",
    "img.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot this image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving figure MsPacman\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANkAAAEYCAYAAADRfVPTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAwAklEQVR4nO2deXgcxZn/v9VzSxrdlm/L+MS3MeYwEAyY0ywBkiVLSLJhQ3YDJJuQ/GBhybWQiwAhQAg5SALZBAeSwCaAOWIbzI0A37fxIduyZd0zGs3d3fX7QwarulpSq6d7ZmS/n+fxHy7V9FS/Xd/unnrrfV/GOQdBEO6hFHoABHGsQyIjCJchkRGEy5DICMJlSGQE4TLegf7IGKOlR4KwAOec9fe3AUX2t6uucn40BHGcMaDIFo4ZI/yfc+Cx3SXI6P2KtiipC2q4ckJKaItkGJ5sLCnQiI4/Pj0xgXK/+GL01P4g2lOeAo3IHgEPx+cnJcCGIIEBRWaEA7hnaxjd2eH1U+6k6owkss60gu9tKi/QiI4/Lh6TQrlfE9p++0EpNkb8BRqRPSp9Ov51UgJDecwML7UQxDCEREYQLkMiIwiXIZERhMsMaeGjP6r8OhSIK0eRrALN4Doo9+nwGVxvMVWRVitLPDpCHrFfUmNIaOI9IaBwlHl1oS2jM8RUe/cOBo5qv3g8DqAzowDCT93efsYfv50ZBdzQWunT4TGcczSrQDXYJuzV4VfEfj2qgrSFlVwzO6R1hh6DHbyMo8In9lM5Q9SwkKWAo0qyA0NnhsGuHawS9unws8HtYHWO+BWOsME2Wc6kxTsP46g02EYDQyST+3PIEZG9uKQNY0LiAC9YWYvt3T6h7ZHTu3DGiIzQdkNDJZ47GBLabp0dwxcmJ4S2h3aU4sdbxNXAK8Ynce/JUaFt5eEA/u2talvnUR3Qsf7SVqEtoTHMfGYktD7XU2HAWxe3odQrXuQFy+vQlhaXpP92bjsml4mralesrsGaTnFV7f5TIrhwdFpou3VtBZZZcDP8y8QEfjC/W2h7/mAQX2qoEtpOrsngr2d3Cm07u71YsnKE0FYX1PHeUtEOMZVh9jMj0fcqexnQcEkrgoZV+LnPjURXxp7IHj41gnNGina46f0KPLVftMNNM3pww7S40PbIrlLcuVGcI0vHpvCzUyJC2+utflzzRo3QdmK5iheXtAttTQkPFr1YZ+c0BBwRGQMs+Q2s9kMO/Zzw4PU9Zn/HY0zsN1BYntPnnK/vsGIHDMEOlr4T7s+RXPsNFUdEds+WMEoMd/UWEyfjb3eXYvnBoNC2OeKT+i1vCmFPTBzaJpN+azp9+OY68c7VlLDv3IyrCr61XjyeygHdMHE4B+7cGIbXcFGMr2cA8MC2MMoNryH74/IYH99bgtdaAkLb+x3yOZvR0O6X7LDP5Dsae7xSv4iJz7M7yyQ7ZHUG3dBP48AdG8rhMdghodqfrY/tKcGKZtEO6ztlX9pLh4I4aLjWW6OyvTZ2yXOkOSnbpjmpSP16cjiPvrCBIqMP3nST8EedA7OfHTksndHPntshtO2JeXD2P3J/FSCs8eZFrag3vDYvXVUzLJ3RGy9rgWLQ39j77+9XkcNLLQQxDCGREYTLkMgIwmVIZAThMo6sLlplXIkqrUJapTOtoD1dLGERHFPCqvTj91hG5wy7Yh444yTJndqAhuqAcb3TGgmVoSkhTv2AwlFfpgptWZ1hb0/uEsmryO47OYoz6jKDdzThZ9tlZ3ShUBiw/LwOyRl9LBPLMswyOKMLyX9MjePG6fHBO5rwWovsjJ4SVvHS+aIz+kDcGWc0vS4ShMuQyAjCZUhkBOEyJDKCcBkSGUG4TF5XF53msnFJfHN2TGh7o82Pm9dU2jpelV/HC+eJK0xJDTh/5QgpNo7ojcF6+YI2BPrcqjmAi1fVSjFqVrl/YQSn14or0HdsDOOFQ6F+PlH8DGuRlXg4xpWKm05r
u+0vMiuMY2yJJoQ85LKj/FiHARhToiHUx33JOXLyH9YGdOma2vWtFgvDWmQtKQWvtYi7uM1CZ6yS1RlebxWPl9YZhvcldg8O4M3WgBTRnc3BmbYp4oViiIxuzUNuxrjKpLlkDMC1y7AW2eqWIFa3BAfvaJHurCI5KYn+0TjDtTaj0PujUBsOGuNe1649LXwQhMuQyAjCZUhkBOEyJDKCcJm8LnzogJBabSjYzeMH9C4rG783p93kR45n91yGI7meq9m1z+mamhzP8ljyfN3ymkgn5OFSok+rZHRmu2STAi75WnQOKRGmdThKvblMkeEHR+8yt914shKPLvnPEiqDnBrVGn6FS64Dq6icIaXZ+147iXTy+iRLavYvUi7oYI6l9+qFHZlwhFXs39DMyeWmm2/oNxlBuAyJjCBchkRGEC5DIiMIl3Fk4eObs7tRYSi1c9+2MA4bco5/aWoPJofFjEB/3FuCjV3ixsyPj0virDqxssfLh4N48ZC4T/GUmgyuqherv2yL+vDo7lKhbUKJiq+c2CO0tacV3G3YJ1fm1fGduWJ1lKzO8O315cIqGAPH9+Z3S6tb39tYLpVtumVmDCOC4q7yn+8ow764aPprJ8UxszIrtD21P4SGdjEvvBmLatO4ckJSaNsU8eEPe0Q7nFCm4oZpoh1aUh78ZGtYaCv36fjWHNEOaY3hOxvKhWV3BRzfn98Nr8EO/7OhXFrouG1Wt5Rd6sHtZVLWqOumxDG9XLTDE40lWGvIh790bFKq/vJ6awDPNokhMSdVZfDpE8Q58kG3F4/sKhPaxoQ03DRDDJvqyij40ebc91I6IrLLxycxpkQ04O92lUoiWzIqLWWreq0lgI1d4vFOrsngmhPESdORViSRTSpTpX4rm3VJZDVBXeq3J+aRRBbwcHx6YlIKdfnOhnL03YrPGHBVfVLIVsV5b+GNmHgPwaXjkpgSFkX2ZGMJ9hkSLX1sZBoXjREnzfpOPxrE8DZTppbLdljexCWR1QU1qd+OqFcSWYmJHWJZhu9uKBciEhQGXDUxIYW6/GhzORLiKeOycSkpF/4f95SgSZz/OHdkGueMEu3Q0O6XRDa/KiudS4+qSCKrL5PP+bUWvySyKr88Rw7EPcUjsrjK0JMVl1PNnL1JTe5n5lBMm/QzW65VOaR+Zv4P3aRfwqQf55CW5pNm/pQj/YwuRjOvTcLMNiYdUybnnLXoBsrq8mdTJhdA43I/MzvokO3Qn8sinlWg9Tkhjn7sYHrt5WMmNflaqSbXPmNyzhlN6gZVtzhHYDJHiqmqiwIuub967c5s9WPgUq0ozuUdAlb7AVwOJOQwcYT21884Rt77Y7YA52zG8WiH/J2z2FgwZ7QOZn77stmPQ35K5NIPYBa30gyhX+8ABsXpc87ts8eOHQp3zkOHVhcJwmVIZAThMiQygnAZEhlBuMyQFj4YgEdO74Jq4QfibWsrcCBhb13lc5PiuHhMSmj7+4EQ/ryvxNbxJodV3DkvauuzaY3hi29X2Q7JuGdBBGNKTNaWi4SmhAe3rq209VkP4/jNoi7bISffWl9huzTR1RMTuGyc6Nd6/mAQj+8t7ecTA1NfquKHJw0+R7xs6HEkQxMZA860WPool7JCU8IqFo8Uv2djl/1Ub+U+XTqeVRIq610qtnk6p9RmJGd0MbEjan+BmQE4qy4tOKOHQthrP3R2Upk8R3Z0258jZV5ue44MBr0uEoTLkMgIwmVIZAThMiQygnCZvOb4WFCdQZUhJGZTxGc71/mYkIYZFWJYRHtawQZD6IzzcJwzMg2PYZnp9dbAsMk7MRABhUuhRioHXm0JwO0cLfOrMqgxhMRsifqkiA6rjApqmGUIIerMKFjX6fYcOUpeRXbbrJgU6nL9O5V47qC9sjgfq0vjJwvFZdeVzQHH87MbURjwq9MjUqjLguV1jhUpKCRVfh2PndElhbrkozD7zTN7pFCXr71Xgaf223PfnD4ig4dOjQhtZoXZ3YRe
FwnCZfL6JPvGmgop/+GhhP07/wuHgli3QvSNxG3mhBwKOgcufblGCnfozBwb96z2tIIlK2uFNp0z159iAHDrunLJx2r3VREAVjUHcN4K8VzyXXMuryIzhprnSndWsZ1oNTcYdsXsOz6LHZUz7MzBsZsLBx2eIzFVQay7sDe/Y+PWSxBFDImMIFyGREYQLkMiIwiXGfBX5mO77fkmgN6cdcVCW0qxfS4Z3X7uDQD4v/0hjAjmY13OHi1J+9dJ58Af95TCZzPUpb2IfIqdGftzBAC+OcDfBhTZt9ZX2P7SYqIp4S3YuTywPTx4p0Fg4BjlS0nth7PBnGp85YoOhjs2FqaQutM0Jz05zRHbIiOKAY6gouGv0xqEbVwqZzh/61lIcQWFKEdFWKd43ukIUxaURvD0tAbpQnnA8X/T38H8EnsR30T+oCdZEXNZVTMuqmhBiaLijqYToYPhvPI2TAv14JctJwAAzq1ow1h/Essjows8WqI/6ElWxMwKdWNhWQQaV/CP6Eh0qH6oYOjWvFgRrUNC92BBaQQzQ7HBD0YUDEeeZGGfLqk1pjLohlznpV4dXpO6wVmTnOhW8CkcJR5xZSury6VTPYyjzKRmtLECCwNHuU/sxwF0Z41leHv7GUfdnWXSQkTYK9dK7lGZlAe+xKPDZzCiMUHNv9c1YkYohr3pEigA/mvMTlR7s9ieKkeFT1zBtGoHjfcWauhLIe1gOkc0hqzNECKzOaJyIG44Z4VxhC3METs4IrKV57dhdEi8yBeurMV2w/633y3qwqIRYqjLjQ32Q10+MT6Je04Wf5Osag7g394WQ13mVmbx93M7hLa9PR4s/ked0FYd0LHu0lahLakxzHxmpFAYQ2FAwyWt0kbWk01CXZ45tx2TDYl0rlxdgzWGeKYHT43gwtFiiEekJYlEn9O7fu98/HtdI84Md0ADcPmORXhw4gZcXZ/El05tET77/MEgrm+oEtoW1mTwl7M7hbad3V6cv3KE0DYyqOPdpaIderIMs58VQ128DHh/aSuChkk877mR6MqIonhhSTsmlIp2uPTlGmyKiHb45WkRLDaURPr6+/ZDXS4dm8KDp0SEttdb/fiMIdRlRrmKF5aIJXSaEh6c8aI4R+zgiMgYICfrz6HfUDAez1iE4MMvlvr1czxmOEZ//RTDMfvzpTGT7+7ve/sbY1DRsGzqewAHKr1ZhBQNT0x5DwAw0p8CWLnl87Nth346Msat2cHid+dljuTYb6g4IrIHt5dJryGtKfkx+4e9JXj5sFjUblvU/m7vDV0+fH+j6IfaF5cdnAcTHqlfxGT3fkJl+MEmsV+Wy4UIOAfu3hKGj4l/6DEJofjlzjJUGl7lmkzCe/6yrwTvtYt39TO8HZjm6b34EwNiIa+JwaP/X9Phwzst4rj3mOQz3B/3SnYwC8+JmdghozMpK57OgR9vDguvdxxA0sQOP99RhnKDHZpNQliWNYbwRqtoh005pAPcHJHP2cz+LSlF6ufEqyIwSOkkNv7HLtW5GJg75kVx3RRxUv1seyl+vOXYcHxa5bYxO3BFdfOg/Z7qGIN7mqflYUTFw+2zu3HjdLGS4q8/KMWdBXKO8wO39vvgo9VFgnAZEhlBuAyJrMjZmgjja41zoXPg+03TsbxrJPanQ/jy3nlI6wz3N0/GnzrGFXqYxADQjo8i5v14FXamyrA2Xok/d4zDOz3V6FD92JMuxdp4Jf7SMQ5vxmrQlLG/e5xwnwFFNrFUtX3gg0mPbQeiGRV+ntN47MLx4YqlvXMZG1IlJ7NVdqnVgAqMLdHxt9hElAWAZl6B5mQF6ks1PNdTD48PmOizb5eMznDIdqIajvpSrSDbkyv8zoYP+RWOMSF3CoMMKLI3Lm6zfeDzV8jO6Fz410kJ/OukxOAdHSahMswwOKOHwuMf6yz6qi5LDM5oq3gZsPKCNttVXYqJqWEVL53fPnhHG9BvMoJwGRIZQbgMiYwgXIZERhAuQyIjCJfJq5/s
gYURnFQthrp8d0M5XmkJ5nMYOaOA44Ul7QgZQjyuWF2DzszgS203NFRicyT/abDnVWWlCidm1AY0PL1YDA2KqwxLX64taOIeO1wwOoVvz+kW2t7v8OMbayrzNoa8imx0SMMkw3J2LgXcCwYDJpZpUukkY72y/jiU8GCvyS55t6kLWnMleBlwQpkmlU7KoT59wSj1cmnONSXy61LJ65Xe1u2FxxDtOywroXDg/Q6fFKzopPO9kGR04N0O8UmbVJVhJzAA6EgraGgXz8VJ/60V8iqy7244NvI46mBSZO2xRGfGg0++Wjt4x2HA660BvN4aGLyjiwzDxwhBDC9IZAThMiQygnAZEhlBuMyACx/GBDL9YcxslCucW/9utymSYRQtnDPouZS9cRCnh8G5M9d/QJHNfXakpYOsuEDOu5gLP94Sxv3bcq+G4gQcsB3mcqyjcuCU5+uKxj2ddtiFciip4CKLYUBdD/T/twFFZpY2zQyn7yBJTUGyeEOwiI9giGaLRWLOo3NmWQMDQb/JCMJlSGQE4TIkMoJwGRIZQbhMXvcufmFyHJPCYmalJxpLpLCPpWOTOMNQ/eXVlgBWNIshMQuqM/jEhKTQtj3qxR/3lgpt40pUXD9NTOncnlakFcxSr47bZseE1bKMzvD9jWHofVoZOL41J4aAYYPwXZvDUhmi4UjYp+PWWWLNs5TWmx+/b6iLAo5vz40Jhdk5B360OSyVbfrGjBiqA+IK9MM7yqRMWZ+fFMfUcnGO/HVfCOu7xPz4F45O4WxD9Zc3W/144ZBYIWhuVQafqhfnyK6YF4/tFueIm+RVZBePSeGMOlE877T5JZGdVpvBtZPFzFSxLJNENjWsSv1WNgckkY0I6lK/PTGPJLKgh+PaSQnB5/dREYo+emIM+OykhBTq8sC2MvTkP2ud45R6OD5vsEMsy/DDTWHBb6Qw4LOT4kK2Ks6B+7aFYYwm+eSEJOrLxMY/N4YkkV0wOo1zRoniWdfpk0S2sEaeIxmdSSKbVKZJ/V5r8R+7IsvoTFqaN3M6Z036qSY+EI1D6pcx6aeb9OvPp5LSIKRYTJm5EnjvnV0xVHVx053mV5ih/B6QccljzwGkDG7PlNafvQxRZv0MKW127U08bBldvlbGQoFA73yQ54j8vWbX3myOuIkjVV3evaQFY0rEMzTLu+hhclytxiFF2yrgUq0oncsXhYFLgZK9zmPjt3CpeqPVfkCv09VYYdJaP2D1ha1S3sXLX5GLAA5ESFHw/KLp8PQ5tso5LnlnO9JDENpptWk8tVgsAmied9F5O5hde6v98jFHZlVkpbyLB+IeLLJYBHCgqi55fZKZ3ZHM0CHXBDODgx25UINRqH65MzMcws2TR8PPGFif9zcPBx6eewLu3XUI23pSDn+r83aweu0LN0fcg3LhFzGnV5Xh3NpyTC4N4vGmDnBwzCsvxdiQH8+3dAEATqsqQ9jrwbuR+CBHIwoFiayIWVwTxmWjqpDQNPyysQWjgz6UejzwKwy/aGzFxJAft08bixq/j0RWxJDIhhE/mDEeU0uD2NGTgocBvzlpEkKKgh2Ovy4STjL8nTrHETdubMTjTb2p2jQOXN6wE+ui+S/CQQyNAZ9kv1vUOdCfP8LoZMyVayYmcP7o4rg7Z3SGGxsqTZeb80VAUXDXzAkAOCaEAqj2e3H3zPEAgEmlAexNpAc+gEt4GMfDp0YEZ3Qh+UdzEE80OlerrTagWdbAQAwosgvHFObiTa/IFuy7jSRU1uuULeA88jCGs2pEx/lZNYUvUs8AnDc6VTSlkxrjzv76CXmd0QC9LhKEy5DICMJlSGQE4TIksiJnbyKN+3c3Q+ccj+1vw1udMRxOZXDvrkPI6jqePNiBlW3RQg+TGIC8+snmVGZQ6RdXELZFvWhPF8kvZ8twnDEiI+2Je6fd72g+/MZEGlFVw98Od+H06jCea+nCvGQJDodL8LfmLpxeFcYLLRHsjDu7EutXOE6rFaMlNA68
1eaH3QL1hWJEQMOJFWJoRFdGyWtVnbyK7NtzYlKoy/XvVOK5g6F+PlGcKAx49IwuKdRlwfI6tDl4w3jy0NHl429s3gcAOJSK4oXW3ifXf23d79h39aXar2PZWZ1SqMusZ0bCWWeN+5xZl5HKRb3W4sc1eaxlQK+LBOEyeX2S3bquQqpHdiA+3F4Ve0MqrlxdI4VadFksA3XfwggSav5fu0os1oJrTyu45GWxqktvGMnw45XDAVy8SjyXnjynscuryApR+M4dGLZG7b/TTw4Xd1JJlbOCVAJ1g2hWQTRS2Bc2el0kCJchkRGEy5DICMJlSGQE4TIDrkQ80Wjff+VEov6+bI14sbEAP8YzGsupoMbzB4OoCxbvutzhpP3VXR3AX/eVFCTUZV5VFjMqnMu/15VRcprv/2+Avw0ospvXVNr+UqdZdTiAH28pfHjHULnbgTEzcNSWZz7aa5HVGLriPowIZxBJ+JDVCvNConOG/15XUZDvvn12t6MiO5T05DTfbYuMKA5Cfh3/uK3hI7/ctkNl+MzP5+O5W97DV38/Cw27qwo7QGJA6DdZkXPyCRE88//eg8J6MxczBkwdFcfyW97DPz9wMq4+4xD+88K9hR4mMQAksiLH79VRV5H5aB/hK1tr8LvV4zGqMo3/OG8/1u2twOs7qgs7SGJASGRFTiTuw+s7qvD6jirEUx7EUx50xn1QGHD5yS3Y0VyK9fsK87uIsIYjv8lKPLpUmD2pMin5TNDD4THkj09pzHLWWCNexqXKKprOkDKEmyjgCBn27fXmSBfvMQzcdH9fXGUwpqcu8crppBMqk9JJhzy6tMfRzDYBhcNrWKVLawzbDoXx5UfnAAD+94Z1OHdWB5bMFtNJm9lB1ZmU719hHCGPPTtwDiQ02Q7GvaiAmb1ymyNpjUF1co5wJuX2N5sjveec+3PIEZGtuqAdY0LifryLVsm58B87oxOLDHFKN75bieU2Q10+MSGJexaIAYurDgfwhbfF16e5VVn8/ZwOoW1vjwfnrBDznFcHdKxZ2ipc9oTGMPvZkUJxdgXA+0tbUWq4eAufl0Ndlp/Xgcll4irYJ16Vc+E/dGoXLhwtJm25bV0F/tQn+9K1v5wPAJgxtgd/+sq6j9o/PTGB78/vFj77/KEgbmgQF0QWVmfxl7NFO+yMeXGBIRd+XVDHu5e0Cm0xlWHus2Koi5cBay9tRbBv6SQAJy0fia6MOIlfOr8dE0rEOfJPr9Rik8Et8+vTurDYUBLp62sq8PR+e1moLh2XwoMLI0Lb661+fPZNMdRlRoWK588z5MJPeHDWS9Zy4Q+EIyLzMA6PBcErgNQvl/3QzOR4xqcG0LtYYKUfAHiOLC589H8zFxfr7df3mP350hSrtjEZo3GI+pG7uW4Yk+n5mXwHMxmLab8j4xHs0I+9jOfXrx1gdn5yZyt2GApW5whM+vV3zkPFEZH96oMyhL3ilW9Py5fvz/tCR6Jrj7Kj2/4QNkd8uHdLmdBmttO/OeGR+kVMwlISKsNPtor9slwubMA58OD2MvgMrzVxk/CVR3eVosov2sZYkwsA/nYghE1d4l3d6Hz/1OmHUFuWQW25+DawvlO2wwcx2Q5Nca/Ur8PEDjETO2R0WRI6Bx7YFobXYIekSZml3+wqRYVPtMPhlGyHv+4P4f0O8bxziXjYFpXPeb9JeFVrSpH6dTu0ocKR0klOc8e8KK6bImbG/dn20mHpjM6V0oCK8TW96QXu/vQ2TBwhVo3890fmHJd+sttnd+PG6WL+/19/UIo7NxZmjhRN6SRi6Myd0I1fXbdZauf8w7pdRLFDIhumaDqw5IenI5o4NoIrj2XIT1bkbD0Yxtf+d6aw2LGvPYgbH52DaML30WIIUbzQk6zIiSZ8eHNHNf7y7uiPalQ3R4J4Z9fx9ztsuDKgyIy+r6HQmlJsOxCdxqdwjDBUnlE50GqyumUNjtEhuc7L4aRiu/pLtV9DcIDhPPrSJOH/g12blMbQ
aTGxjxGFcYwyhOdwAM1JBXYX1EcGNWlJvC2lIFssc4RxjHApJGlAkb27tHWgPw+IWWH2QjG7MotnzxWdsHtiHpz9D3uORoUBqy9sczTv4j0nR3GRg5VsljcF8aUGe0+7uoCOhktaHc27+PTiDtSXiTeGpatqsDFivUC9m0wrV6XC7E5Bv8kIwmVIZAThMiQygnAZEhlBuAyJjCBcJq9+sp+cHMH8qqzQduemMF5tCdo63sVjUrhlZkxoe6fdj2+utxfEWOnX8ZQhFCSlM3z8lRrbMW+PndGJ8YYQj6++V4ktOWx6tcvcygx+ulAMDWqMe3CdITTIKh7G8ey57QgYbtVXvlpje3PtXSdFcUqNuAH6ri1hrGi2N0eWjErh9tniHFnT6cN/ra20dTw75FVk40s0TDdkGApbLIJgRoVPl453IGE/xZmHcUwrV4Wl61wLQ0wsUzHFkPs+6CnMjsOQl0v2ygUGYGq5KhRm5zy3EJFxJnOk3GfffxX2yefcksrvC1xeRbarxytF3OaSn7Ezo2B9p/hEyKWohaozbDCEm6Q1s6gn62yP+tBjOMdCVHQBgJ6sbC+zsA+rcACbunzwGy6hloPB9vR4UGkYo9VqOWZEMkw65z15LnyS12+73eEcfSuag7ZfI8yIZhX80yu1g3ccAtfbdAi7wZaoz9Hz0zjDJ1511l7f2eDsHFndEsRqmz9HnIIWPgjCZUhkBOEyJDKCcBkSGUG4zJAXPnKpcFKM31NMHI/nnAtO28vq8Yz5IwdjSCLTOXD2SyMQs7AEbZYNyip3bwnjwe1i5qBkgZa988VN71fC72AJooxJxqhjifu3leFXH5QKbcaEpUNhe7cX85cPHvpU4eN49cK2IQltyE+yzoziWKqs/oirCuLO+UyHBTGXbXqskdAUJBysb69xhg4LsYCaroNjaKGrdGUJwmVIZAThMiQygnAZEhlBuIwjexf/c3oM5X5xZeyRD0qlbFCfPSEuJVN5en8I2wxhHxeOTuEUQ/WXN1v90h60eVUZ/NO4lND2QbcXf95nrwJIiUfH12f2CG1ZneHeLWW2s1BdP60HNYZMWb/fXYKmhGj6q+oTmFYurvY83xTEuq7BE82cXJ3BxWNFO2yPevGUoRLKhBIVn5sspj9vSyn49QfiSq5VFHDcMisGb99bNQfu21YmlWOyyqcnJjApLNrhmQMhqfrLeaNSWDRCnCMN7X6sNOxlnVWRxRUTxNTme2NeLGsUbTMyqOGLU8W039GMgod22LNNXxwR2ecmJTCmRJxIT+0LSSL7+LgUzqgTDbOh0yeJ7My6tJQLX9UhiezEchU3TBMNs7I5YFtkIS/H9VPjUqjLT7aW2c6HffXEhBTq8uLBIJrE08PFY1JStqq9Ma8lkc2qzEp2WN4UlEQ2ukST+u2Ieu2LjAFfnBqXQl0e3lmGpM2Vv0vHpnDOKNEO26NeSWSn12akc/EwSCKbajJHXmvxSyKrDehSvwNxT/GILKszZAwhP2ZzUuWw1E8zOZ5Z0KQO+XhqjqnzsjqE9dlsjsdTLdtG7mf1q3UTu5qFm3CT71BzdM1ldSYW7cvxeGZzxCxLsmZ2ziYG4yb9zPKBcsj9cr32HzJgVZeDN90k/FHnwOxnR0p+Mr8iV53M6JCqTvoYl2pDZXVIr2JexqXAP5XLQlPA4TO8legcUsLMk6ozFvMucinK96jxjx5TYRzbPt5iKe+iVduYnXOWm08wIwrj8Bm6aVyeTAxciv3S0SuUvowKanhvaX95F8VKm35F9hmlDfYCgDcvarWUd9HqHPEwDq+Dc8TMNr3XXuxX6dOx8bIWaYxj77/f3aouxoH0R5YzS3c6lTNLd1gd7MgFdQqnj2fdNlbP2QydM6QtfJa7YC/j3T9XrM4RjTNLwaFW54jztjkKrS4ShMuQyAjCZUhkBOEyJDKCcJm8JtK5aUYM08vtba9/8WAQf28KOTwie+gcuOm9Cmk1MJfohOun9WCeISdlLqzt
9OERm/6vSEbBDQ2VQpvKme2KLm5w5fgkLhyTGryjCduiXjy4PezwiPonryI7ozYjOaOt0tjjKRqRAQwvHHJ2LKfUZBwtnaQAeMTmZ1M6w3MHi8XW5syoyOKycfZEVuHz40GHxzMQ9LpIEC5DIiMIlyGREYTLkMgIwmUcWfiYX5WR9n1tinhthztYpTagYZJhP1wkw7AzZq9iipdxLKgWV/g03lsFBIY9eydXZ6XVxXWdPmlP3JzKjLBLHQC2Rr3oUfN/fwt7dcwwFF9Iakza4e5TOE6qGtwODBwn12SlO/XaTp/pJlwrTCvPotIn7pfa3eOxlH8jF0o8OmZXirZJ68AGC1EQg+GIyH59epcU6tJbmN3dibRkVBo/MZQCWtkcwLVv2SsFVOHX8dTiDinUZcYzI4V9cgoDln2s09IG4Z+dGpFCXS5/pQZrOvNfkHxmZRZPLe4U2nZEvViycoTQVmNiB7PC7B4G/OljHVKoy9znRqIrY09k35kTk0JdvvZehRS24zQnlGl4+hxxE/mBuAeLXhw8g9Vg0OsiQbiMI0+y696ukl4X9+VQkscqKw8HcPkrNUJbNAeHcDSj4IrV4vF0yLFZOgeufq1aCncwK/Hz5YYqqR7Z9u78lu75kC0Rn2Qvs1yFHWnZDhqX49s0DnzqtRrpTh3L2s9/eMfGMH66TXSiN/a4P5f29Hgk21iNoBgMR672pkj+X30AoCPt7Lu6ypnF1zhmKWIZQEEqavZHj6pYOr+sRTtwMKxz+LX3A5u/p3MlqVmzjR3odZEgXIZERhAuQyIjCJchkRGEyxRmmasIKPNxXD4uOXjHIfDioSDSDq1IFZKgh+Oi0fZ2uPdHX59iMVLh03HOSNE/l9CYIzXJj1uR1QV1/Py0iGPH688ZPRyp9Ol46NTIkOtwDWfGlWjSfDgQ9zgiMnpdJAiXIZERhMuQyAjCZUhkBOEyw3rhw2OWnhpy2ulC4le4dCdLm6TpzgfMJA15bz2B4rGXT+EwLh1lLKYrL1aGtcg+MSGJuxeIoS4vHw7gurfthbq4wfPntWOyoRTQJ1+twdoChLqcUpPFk2eL4Rw7u724aNWIfj6Rfx45vQuLDUvp33i/Ev93oLgT+wzEsBaZAkjFBIyBlIXGq8gFDwo1RMbksRiLNhQaL5OvqcKK28c2GMNaZNuiXjy0o1Ro212gMJL+WLa3BNWGIoCHU4X5KXwo4ZHs1V6gsfTHs01BbI6K13BHt/s789vTimSbqEnokh2Ka0YOkY0Rv1R6p9iwW2DPDQ4kvLhrc3mhhzEgT9os4JgrLSmPa7YprtsYQRyDkMgIwmVIZAThMiQygnCZvC58RLIKWm2uZsVzyFOY1Znt7x0KuVQ9iWbs28aM7hyS2WgcaEu7by9jjsqh0KPav6YRh1YNreJIYXbr5OrvsHtR8ulnKaYxFtNY+qNQY7T3vQUrzG6dgrlhC/S9Q6GYxlhMY+mP4TDGXug3GUG4DImMIFyGREYQLkMiIwiXcWTh43/P7MSIgFi55IaGKjTGxcPfdVIU86rEmtE/3FyO11sDQtuXpvbgivFiJqknG0vw2B5xA+eFo1P4+oyY0NbQ7sf/bKwQ2qaXZ3H/wojQdjDhwRffEUNiKn06ln2sQ/hJndIY/vm1Gmh9lpsVcPx1cQdCfXLccwCfeaNGyof/69O7ML5EDHX5xppKbDOk7/7O3G4sqhVDPB7YHsaLhwZP5HLp2CS+Mr1HaHujNYAfGPbiza7M4p4FEaFtX9yL6xuqhLaagIY/nilWf4mrCq56rVqIg/MwjqcWdyCgiHa4+vUaaQX6d4s6MTokzpGvvlcppeX+3rwoFtaIc+TerWGsOiza4QuT47iqPiG0Pb0/hEd2iXtFzxmZwq2zxDmyttOPb64X58jkMhUPndoltLWkPLYrBPXFEZGdWJ6VSicZiywAwKQyFXOqxAlX4ZO9S2NKNKnf6hZN6lfl16V+LSk5W1SJl0v9zFKUeRSOOZWq
VDpJggGzKlWpdJLXJCRjWnlWKp1UYmKb+lLZNtV+a563moBsh/1x+dKWeuV+xkIhAOBjwGyDHWJZBgZx4ZwBmFWZlUonmYUbTS9XUW+oJRcyscMJZfK1rzSxw6iQ3O/tdrlfpV++9maFQYIeuV9l3BlXhiMi2xXzIpoVT9As/+C+uAdVhjCGmImTuSXpwTZDvzYT8USzitTvgEk1maTKpH5NCbmfpjOp4opZ1RPw3rpeIYNQzQrf7Yl5pUjtpMkxmxLyOUcsOpS7MrIdDpqcX0KV+5lVTMlyufJMXGWSZ4oD2BH1IWB4ohur4ADA7h4vEtrgdthvYgezSj2tKflcWpNyv+6sfO33m8yRtC73O5x0Jr1fnp3RBDG8seOMJrUQhMuQyAjCZUhkBOEyJDKCcBkSGUG4zJCW8BmAly9owwALkgRxTMPY0Pf/D01kDBgVyiU0kSCOPwYU2cuNjXkaBkEMbz43wN8GdEYzllvqVq+XweNhSKd1hEJHf/5xDqRSQ38iMgYEgwqSSR1+f++xPySZpCfssUhQUcAAaJwjU8S/UzjvP5eCq5HRn/nMGCxcWI6bbtqGJ56Yj8CRageHD6dx7bWbhny8+voQHn54Jj75yXX47/+ejFNPPbrJ85prNqCzM+vY2Ini4JGZM1Ht86EhGsWde/YUeji2cO1J9u1vT8aePQns3ZvE5z8/FhMnhqAc2YuSzerYvz+Fm2/ejp4eeeOvGRdfXIuzzqrCo48exNe+Vo/6+hBKSo7uLdu3L4nf/rYJb78dsTtkooio8nrxgylT8IumJvRoGqaEQriyrg4AcOeePTicyQxyhPwy0JPMtSX8tWu7oSgMp51WiYkTQ1i2rBnNzb2hHD6fgkmTQsLr3mAcOJDC9u1xnHVWFaZMKcFbb0Xw5ptHQxPq60MIh4d11nGiD2ldxxuRCD5IJLA3mUS3pmFcMIg3IxGcX1ODWaWlgx+kSHBNZMuXt8HjYVi6dAQ4B959N4LubhWRSBa7dsWxa1cCmtl27X7YsqUHb7zRhc9+dgx8PgU7dvRgz54EVFU/crw4YjF18AMRw4KErmPZ4cNI6Ud/a2eOtE0MBjG7rAxjA4EBjlA85OXWryjAT386AwDw97+34Be/OJDzMW+4YQI4B9rbM7jhhq05H48oXtiRfx/K7Yd79+ILY8fi5vp6fH3nzgKOzBp5EZmmAVdfvR7ptA5VdWaF6Hvf243334+SY/w44KsTJiDs8eBzmzeDA3jwxBNxQjCIDxKJQT9bDORtW1UqpSOVck5kmUzv8dJpWro/lrlt4kTsS6Xwp8OHwQB8f8oUjAsE4FUUjA8GcefkyfCwoe7ByC95eZIx1rs6qKocjY1JbNnSM/iHBuGUUyowYoQfqZSGVas6B/8AMSyZEw4jwzkOpFI4nMlgQTgM5YioyrxenBQOQ0FvrfBiJS8i83gYvvKVegDASy+1o7W1d/m1vT1j+3XviitGAgA6OjLYuLFXtLGYasvJTRQfCoAanw8eABfW1KDW58Mvm5rQlsmgxu+HlzFkdR1t2Wxek4rbwfXXRaMf7qKLarFs2Tw8/vhclJcPXePG49XU+LFs2TwsWzYPZ59dPAXZidyo9vnwhzlzUOnrzWa1oLwc902bhmu3bMGhdBqcc+xKJPCFLVugFvkPc1efZE8+2YzNm2P40Y+mO3K8AweSuOaaDfjtb+cIjmji+OKWnTvhAYpeXB/i6pMsmdSxc2cc9967F9k+2awikSzuuWcvEomhvUlrGtDensX99zdi9+6jK0ucczz88H5s3hwb4NPEcCKmabi3sRE9aq/vc3cigYcOHAAHEFVVdKoqurVi/iV2FNd/k3V3a1ixoh2nnVYBn69X052dWaxY0WH7mK+80olx44Joazu6tWblynbEYsPD6MTgpHUdKzs7MT8cRtjrxY54HKu7ugb/YBHi6i58gjheKMjeRYIgeiGREYTLkMgIwmUoNoQAAEwbV/rRTgpV49h1KF7g
ER07kMgI+LwMD3x5LkJ+D8CAzlgGV93xHjSd1r2cgF4Xj3PqR4bw/A8XoTTowXd/vw33/WUXasv9ePGuRait8Bd6eMcEtIR/HLNkwQh8avEYnDghjNt/sxWbG7vh9TDMnVSBOz5/Ijbv7cbvVxxAw7bh6Z/KJwVLpEMUN6OqAphZXw5d51i/O4p4qteZv2lvNwBgzqQK1JS3FHKIxwT0ukgQLkMiO84Z6OfCQH8jrEO/yY5jgj4F08eX4Wf/OReReBZ3/mEHKkp8uOmTk1FZ5sO/3bMW+1uSyKgUozcYA/0mI5Ed55QGPThzVjVu+Zep2LY/Bp9XwYS6EO5/ajde29iBdJYEZgVa+CD6JZ7SsGJNG06bUY2gv/fXQ+PhBFasaSvwyI4d6ElGEA5Au/AJooCQyAjCZUhkBOEyJDKCcBkSGUG4DImMIFyGREYQLkMiIwiXIZERhMuQyAjCZUhkBOEyJDKCcBkSGUG4DImMIFyGREYQLkMiIwiXIZERhMuQyAjCZUhkBOEyA+b4IAgid+hJRhAuQyIjCJchkRGEy5DICMJlSGQE4TIkMoJwmf8PN7z/q9pA4kIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(5,4))\n",
    "plt.imshow(img)\n",
    "plt.axis(\"off\")\n",
    "save_fig(\"MsPacman\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Welcome back to the 1980s! :)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this environment, the rendered image is simply equal to the observation (but in many environments this is not the case):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(img == obs).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a little helper function to plot an environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_environment(env, figsize=(5,4)):\n",
    "    plt.figure(figsize=figsize)\n",
    "    img = env.render(mode=\"rgb_array\")\n",
    "    plt.imshow(img)\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how to interact with an environment. Your agent will need to select an action from an \"action space\" (the set of possible actions). Let's see what this environment's action space looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Discrete(9)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "`Discrete(9)` means that the possible actions are integers 0 through 8, which represent the 9 possible positions of the joystick (0=center, 1=up, 2=right, 3=left, 4=down, 5=upper-right, 6=upper-left, 7=lower-right, 8=lower-left)."
   ]
  },
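  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also query the size of the action space with `env.action_space.n`, or pick a random action with `env.action_space.sample()`. As a plain-Python illustration, the joystick positions listed above can be written as a lookup table (this mapping just restates the text; it is not part of the gym API):\n",
    "\n",
    "```python\n",
    "ACTION_MEANINGS = {\n",
    "    0: 'center', 1: 'up', 2: 'right', 3: 'left', 4: 'down',\n",
    "    5: 'upper-right', 6: 'upper-left', 7: 'lower-right', 8: 'lower-left',\n",
    "}\n",
    "assert len(ACTION_MEANINGS) == 9  # matches Discrete(9)\n",
    "```"
   ]
  },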
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we need to tell the environment which action to play, and it will compute the next step of the game. Let's go left for 110 steps, then lower left for 40 steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.seed(42)\n",
    "env.reset()\n",
    "for step in range(110):\n",
    "    env.step(3) #left\n",
    "for step in range(40):\n",
    "    env.step(8) #lower-left"
   ]
  },
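  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `step()` actually returns four values: the next observation, the reward earned during this step, a boolean telling whether the episode is over, and a dictionary of debug info (the loop above simply discards them). The toy class below is a hypothetical stand-in, not a real gym environment; it only illustrates the `(obs, reward, done, info)` protocol:\n",
    "\n",
    "```python\n",
    "class ToyEnv:\n",
    "    '''Hypothetical stand-in mimicking the gym 0.x step() protocol.'''\n",
    "    def reset(self):\n",
    "        self.t = 0\n",
    "        return self.t  # initial observation\n",
    "    def step(self, action):\n",
    "        self.t += 1\n",
    "        done = self.t >= 5            # episode ends after 5 steps\n",
    "        return self.t, 1.0, done, {}  # obs, reward, done, info\n",
    "\n",
    "env_demo = ToyEnv()\n",
    "obs = env_demo.reset()\n",
    "total_reward, done = 0.0, False\n",
    "while not done:\n",
    "    obs, reward, done, info = env_demo.step(0)\n",
    "    total_reward += reward\n",
    "assert total_reward == 5.0\n",
    "```"
   ]
  },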
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Where are we now?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAALQAAADnCAYAAAC313xrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAmgElEQVR4nO2dd5Qcxb3vv909eXZnZrNyFpJWWYCEJJAQCIkgk+zjDMbXARMMsp8NPC7BBMdn+2KMDdhckiMYIwEGARIKCIESKMfVKuxKm+PM7MQO749F0129o+7pyWrqcw4H1XZPV1X3d3oq/OpbjCRJoFDMAlvoAlAo2YQKmmIqqKAppoIKmmIqqKAppsKidZBhGDoEQik6JEliznRMU9ANd9yR/dJQKDlEU9AsQ34Rpr1RA3/8jF+OgvCDSUHcOSmYSK9rseObH5YVsERnDy/O78KCmlgi/dv9JXj8YEkBSzQQr03ErmVtKZ+vKWg1vATwqrf99IsmwVflSaSP7WtEw6EmI5fNCFGVljCwjJTkqNuTYpLnW2h40Vh5DAlazYyFtfBVecGycqajp4wAy7E4vv9kJpemUNIi7VGOmRdPhq/Kg32bD6OzuRsAUL/7BJrqWzFy4lCMnTYya4WkUFIl7Te01W4Fy7LgYzz2bz2CCeeOAc8LOLWvAaIowmLhsllOCiUlMmpyAMDE88dC4EXYnTZ4K0oxdOwg2OwWtJ/swqWDIrh3SiBxbkMfh29+VJ5I21gJqy7pIK63eE0lJMhNmJcv6kSFXW4p3/2JF9u7bCmXr9Iu4KWLuhJpEcBla6qIc1Zd0g6b4rfqGx+W4WRIvjX3TfVjUU00kX610Yk/HJI7T7XeOH5/fk8iHeAZXLu+UrNcr13cgRKL3Iq9basPB/3WRPr7EwK4dngkkX6vxY6f7ZX7KiPcPJ6b251Ix0TgirVkvVYvbid+gr+4sRyd0dRfNLMrYvj5zN5Euj3K4ssbKxJpFhJWLyaf3+VrKxFXtHufn9eF4S4hkX50TynWtToS6euHh3HbBLlTv7/Xgu9vS79Tn7GgnW4HkbY7ZbGVWiVM8PKJNKsa1mYA4vjpvynPGlPKY5BTFrTLYmxo3MKSeQjqXiSAczw87IrnbFM1xAY7BeIa1e0CcdzBkfXsiel3ZMaV8vDa5Lo4OLJeNU6RuOb+XvJR2VT1ipBFAtBfL05RFIvB/p7bQpbBE1LdGCb581MyuoTH2FK5cKVWsp4+G5lHkM+sU5qxoI/sOo6qoeXwVnrQWNcMq82CQSP73xSb2mz48vvyGzkkkIWNiyCOAwNHLW7dUgYbK9+Evb1WGKEryhJ5JPs63LCpnHiTNYfJt9jjB0vwj2OuRLpJdbzObyHy4FP4zn3rozJCYEcC5KN49ogbb5+SXxZtUVJMp0JkvZJ8T/HVjeWEwLpjxrpMO7vJ5xdVjTiI0sDnF1cV5IfbfXAqvqwH/WQ9325y4LDib/5CCvrIruNwuB1obexEw6EmlJaXIB6No+V4OwCgPcqhvf3MP3EiGHzQbtfMY2tn6s2LZMRE/Tw+1Dl+yG/FIY3jAZ7VzUPN5g7t848GLTgaPPPjCQv6eW4yWCY13TG9PPTv7cc6zcOmMDfgBZEJGQnaV+VBaVkJQv4wIqEoPBUlEAURkgQEuoL6F6BQskxGwUmhQBgCLyAUjMDtcaKvN4R4lAcf49Ha2KF/AQoly6T9hj5Z14yOpm4Eevrg7wzC7XEiGo6Bs3JgwKCn3Z/NclIoKZG2oJuOtgIAWk/0v4nDwciAc8aVxrF4UHTA37V4us5NDNvlg++MCxKjAWcjvAQ8cyS/cRgMJNw8vs/QZ1a3OFAfyHgs4oxkdOWlgyPEGPH2TisOB+RRiCk+HvdNCyT76Bn5U5076UhELrlnSoAYtjsbiQgFEDQDw8+3KcwVr6BvnxjEzPJ4Iv3fOzyEoCmUfENXrFBMBRU0xVTkrjGTJSZ44sSM2ok+
DkE+9e+hhZEwwUNOz+4zONt4tjLZGyfSh/wWQ/HOJRYRI93ytDUv9U8yFTNFL+i/XdhFxHJ8dWM53m9LfQas0iHiHUUAjSACI1cMzmoZi5W3Lu0gRm/OfbMarZHUe7/nV8TwlwvlAKjmEIvzV9Vks4hZp+gF3RbhoHQriyYLWtBAlPofxGmEPA8JFpKWMEu0KQWDw0dRkSHunZEvQ6HISND+OIOuqCwQdfBKNrhyrXYYph5tEa7o3yq5Yk6G9f6w3X7W3buMBP21Dyr0T6JQ8ggd5aCYCipoiqnIeacw5/bTUuZ5SMhDOXNMOsXPRr2L7b4xWobnp5YvJw7Wvl4Df9zoS91ojY12LJNdPxvXOBvJtN7pdOpz+3y9VhH7rm4l/jb0scfSswLLDrkeJsvG9T87Q3kk5rt3tA1NMRVU0BRTQQVNMRUZtaH/Or8TU8rkwJ+f7ynFSyfk5f5XDgnjZ7PkpVjHAhyu2yDP/NlYCVuvJJ0lZ/2nGqKiXbZ6cTuqHPJ89y2bffhIsWL6lnOCuPkcedXEpjYbbtsqG5VUOwS8q4jlECVg1pvk7NfWK1phU8zqXrOuAif65Fvzy5k9WDpUXnnz92Mu/GpfaSI9zRfDi4qYB3+MwYJ3q6HFxqVthEfF1z8ox94eOfDnnil+fHlUOJF+65QD9+7wJtKjS3isuLgzkY4KA2cGP7mqFQrbQVy2uhLtCqOZJ+d0Y26V7D765CE3nq6TFwnMr4riD3N6EunWMIul78lmNiwkfKJyBp39VjViihnj1y7uwMgSOcDp7k+8eKdJtmf46qgQ7lKYEe3utuLGTaQ1ghEyErTXJqFSsWJFbZZi40Ac746SHQgG5PFklNlE4hyr6jfFZSHLoDYyYRkyj2RGMxV2kVixol6OVWIl83BbyItYWDIPC6PfUSq3iYTRjEVlwlOirpcqT05Vr2RGMxV2kagLqyqWx0reW6f6+bFkGeLqPBj95+dTPT87S+bh4Mg8vFaDwTrqImUybDfMxcOhEEJ7hEWv4niJRSQi5WIi0KB48zGQCFcdADgS4KDsOY9y87AoRNwUYhES5D+U2wSU2+Vi9vEMYRTDMRJGK94QEjBgCdDYEh6MKkRVaWdV4xCIL0pPjEGH4k3nYCUMU4RZChJwTMNTAwDGlPCEwBr7OCIWpsouEIIPxBkiOMjKSkRopyQB9ao8x5WSYbPHghwERfjoEKdAOFF1RlnCjMbFiRjikp8fLwLH+5R5SBinen71AY5YEzrCzRNOVC1hlgj/9VpF4hc4IoCwYTM6bJeHcWgKJX2MCpqqk2IqqKAppoIKmmIqDI1y3D/VTwzJqHnjpEPXhJDInJHw0HTSYem+nR5DRjOXDorgEgNmNiKA+3d6dc9T8qWRIUwri+ufmEN2dVvxsmJINBUendFraGJ6TbOd8G7Wg4GER2eQz++BXR6i46nH3Koolg0daFJ0GhtrLFbEkKC/Mjqsefyw32JI0BwDfGNsiPjb/Ts9hsJdZpTHB1xDC0E0LugFNVFcM/zMNz0frGhwGBb0DWNChhyh2iOsMUEneX4P7fYYWuo1ycMben560CYHxVRQQVNMBRU0xVTkNB56kEPAVEVnqo9ndN3y1SyojsKumJL9pNOKzlh2l9NfOihCzNptarMRs5HFiNsiYp4iDkOQgLUtqbd/U6HCLmCWwrswIjDYaMATBQDmVUXhVsxG7u625tQOIaeCvqAqhidm9yTSdX4Oi1ZrB+2o+e15PUmMZrJ7Q/50QTcRy7HgnSocDRa3oAc7RTw3Tw6IigjAuJXZNdCZ5osTeaRjNPPzmb1EeMOtW3x4/aQza2VUk1NB98QY7OqWI8ga+4wL8UCvFa0RWdCZ7pKUjD09ViLoKZok0KfYiAog7m0ss5iepATiLJFHR8T4l/yQ30rEbvQY3LjIKDkV9PpWB9YbGAZKxg0ZhBKmit6egsVIY8iCqzI04dFje5ct4zy+uzn9PQfTobh/VykUg1BBU0wF
FTTFVGi2oZOtgtDCqLtlNuBFY+UUDcQZyHkwhu9FtomnYYQZFZgB21FrYSQGI1sIknGdaaEZ4M8M/2VOJWpnJdRf10L8bcS/BxFrCinFC8tIaLiefH5jVgzSDGDLBlLj3TTAn/LZgAqaYiqooCmmggqaYioymil8dWEHpiuCj36yy4O/HHMn0lcPC+N/zutJpOsDFixRGJWkwkeXt6HaIXeDb/qwnAiQWT4xgO9PDCbSG1rt+K+P5NnFQQ4Bmy6XzVBECRj/GhnzcODqFmJlxOI1VYQNwe/O78GyofLihufr3XhkjyeRnlkWwysLZdOX3jg7wMxGzc6rWlGq8KC4fkMFdnXbEumfTOvFDWPkwPfXTzrxg+2+RHpsCY93F7cn0lGRQe3rg4g86q5pJoKu5r5djTZFYNDz87pwUbW82uexAyX4/SHZQGdhTRTPzu1KpFvCHOa/YywWZ83idowuke0U7tzmw39OybEc3xjThwemyatednTZ8IX3098ZIiNBW1kQQT1qIxOWIY9bDS6nOf0Z5TXU3VtOVQaL+jdHVYZkRjM2nTwsjKTKg6wHo8rDlsL4pV69LOp6Mdp5SknW+dg40jRHnYe6DOrVLSzI40aXQyXLQ60RLgsaUZLRsF2pVST2EAzxDGGWYmMlInRQkED4eqQybOe1isRNCMQZYq89BycRjj9xEUQwDAMJPhtZjW5VgEyZjVR5b4whyuC2iIRZSlQAEV7KMRI8CiMaCfpBOD4bOTjpjzPEOLCLI92cYiLQp6gXy0jwWo3VqyfGEOs1SywiEZQVFhhEBPm4lZFQoshDlEAYCaUybKd+fn08Qxy3sxJhdsNL/UFRWmgN22X0htbLOCYyiMUyG5Ps1ckjonoIaiQw6NYpg1oIavp4Fn0axwVJPw81eoIPCSxCGhMOYgp56tVLbwPTeBr1UqP3/KIig2iGeSihnUKKqaCCppgKKmiKqdBsQ78wr0vr8ABePOrCexmua3t+fndOd1YSAXzzQ2OLBm49J4g5lTH9E3PIR+02PKXwbk6F5+Z15fSNlYJrsC6XDY7g66Oz58uhKehLB6fuSAQAa1uMLaBMhhEXpHRINmynx2Rf3PC9yDb+uHH1XDIoashophAMdwlZvbe0yUExFVTQFFNBBU0xFTld9T2mhMfCGrl91BNjsaIxd54M6XLjmD6irfnvBqfmTgX7eizY0mE74/FsMLcqhkle/ozHvVYR14+Q40sECXjxqPuM5xeK60eEiBnN9a123e06MiGngp5WFscjCrvVOj9XlIJ+cJqfmGbe0GrXFPTmDhse3GXMwdQoj87o1RR0lUMk7m1EKE5B3zkxOMBo5qwVdHOYxdun7Ip07iygMuHdZgesiuCfvhyY2WSbIM8Q9zZegPWAqbCxzY46vyzoXGsgp4Le0mHHFgN+0YXili35NUPJBi1hDt/enHsTnky5z6AXd6bQTiHFVFBBU0wFFTTFVGi2odt13CZ9NnHAVsVGkFLII9sUwgynUHRE2AErRIqduKgfKz5U45imoGfqrIt7Y1EHZpanvztUTGR08zArnIUjgnt4XoCh3ZJS4Ny3zr57u7vbimt03GC1blNORzkoyWE5FvOXnQuLTb79H6/di94Ov8anKKlA29B5hmEYXHTN+WAtHE6v55RECbMWTUZZTX6HuMwIFXSeWfj5OeAsHD54bStikTgkScLWd3fC3xXEjAW1qBxy9o2JFxM5FrSU5D+9c4xeI9PjqZ6TfTa9sR2RUH+sy8fv7UF3W6+BT2ejXrm4d+mUM3vktA197fAIfn9+TyJ9JGDBotWy0YydlXDkWnIZ/MhXSRuD7Ve2ocYhR+V/7YNyvK8wmvk/k4JYPkk2mlnXaseNim0sBjlFbLtCNpoRJGDUCtJopv7aFsKmYOG7VTiaw3gDALjo2tmJf89eOiPx75N1zSl9flypgHWXKY1mBm4adOL6FuKNdd5b1cQOVH+7sAsLquWVOL/ZX4LHDspGM4tq
onhxvmLToDCL2YpNg1hGwgmVDcXYlaSNwYYl7RhTIk9937b1LN40CFAv0xn47UxlGY/mOYz+NZTHz2SXnI3lREZgspBhxvfO4PFkp2ajDNkkp4J+rdGBN0/K9lRqLUVFYPSrpH2VeoXUBauqiRvJqy7yP/tL8PgBea2d+vMtYXZAHmomvkYej+ehxbHxta2Ys3QGbA4btr67ExPPGwtPean+Bz/lSIDTrde4Fdr1unFTOfEGV4/Rr2+1E3mob4soDXx+6jwuXV2l+fyyTU4FLYHREYfecRAuSckQwUDMMI98R6ptXLkVs5fOwPY1exCLxjHrkimo23EMY6aMMHCVzOslSAy0zPPz8fyyDR2HLgAXXDkLFiuH85dMByQJFqsF0y+aBJZj0Xi40KU7u6HDdgVg2+pdEHgBO9bthSCI2LFhH8LBCPZvrkNXS0+hi3dWQ9/QeWbzqh2I9EWx9d1diISi+HjtHsTCcez+4CDiMR6idvuJooOmoN+9tF3rMOH7mw42VsJ/FnVkdA2jiAAuN+hRnU0ifVHi/9FQ/7BZNJx9I5u3L23P+0/wVWsrM+qTTPLyurrTQlPQtb7MBKsHk4c81KRjNHO2MsnL591ohmGQ0dyJyyJlpAnahqaYCipoiqmggqaYipyOckzxxQkzlI4Iiz8eNuagmQ/uneIn9mb5w0E3OmPFablwmkq7gFsnyPsK8CLws70ejU8UhtsmBFFhlzsu/z7hxL5ea87yy6mgx5Xy+O54+abX+bmiFPS3xvURRjN/PepCZ2Hdc3Xx2STi3kaE4hT0F0eGCKOZnV3Ws1fQ9QEL/lznSqQ7osX51nuu3g1OEbXUq2NdO9nL49vjtHZdyZxar/bStp4YQ9zbfE8xp8q/TrhQYZcFXR/I7dRHTq++p8eKPT3Fvwrj0T3G3mwXVMVwQVVhX+EdUQ4P7S7+e/vEofz+ItNOIcVUUEFTTAUVNMVUaLahD/Qaa2LrbfSYCkbzNEo6sT9NIS7n5dLjVBqunQd7LTk3mtGy/E2F7hhr+N5qGc1ktDVypqSyNTKleElla+RcoLU1Mm1yUEwFFTTFVFBBU0xFRj0djpGI1q4ogWj/MpAGxOMW64xWvrGo/BQEqX9R6mlYRiLeNhL6F7Uq/2LRubfqPPpXXMvn5OP56dYTEtFxHVhPY2Qk6JUXdxLuo/+9w4MXFBvXXDM8gidm9yTSdX4Oi1ZXZ5Kladi1rBVem/ywl62twM5ueWeth6f7cdNYecvgFQ0OfH+bbBM2rlTA+iXyyo6IMNBopv66FkKQ575JGs28OL8LC2vkGc9f7yONZi6uieIvFyqMZkIszl9lzNH0vcvaB2wapDSauWlsCA8rNj/6uFPffVSLjMei9PblJo/Tt7MSrXsnqY4nNdlKYQzK2PPRPp7ukFemZTBChsN2yQ4zKR//bA/bZXbv9I/nPo/Uhu2yUQ8SrWG7DN/QesL7LAgzXTK9d6nc23zkkesyGIOOclBMBRU0xVRQQVNMhWYb+pOrWg1d7Bd7S/HyCZf+iVnke+cEiaVIm9psxPBWKmy+vA02Tu6cXLe+Aif6Uu9eTPXF8ILCRzkdbvig3NDSpFElPF5d2JlIRwUGc982NiT6h9ndmKtYqPDkYTf+XJffgPyvjArhx5MDWbue5lOrdhhzZXFyWRx/SRG3RSLKqRzbTZUqh0CsKTRqzmJljd+rgdcwVm4LQ+YZ0bIRPQM+m0hcw12A5+fkpIzvnRLa5KCYCipoiqmggqaYipwuw5hTGcW3x8nxCM1hFg/syu9KZZ9NxP+bJe8uJUrAzVuMdRq/Mz6I2RVyzMqaFjteOp7bzu9XRoVwyaBoIr2lw4Znjrg1PjGQp+d0E4E/P/rYi9546u+wKb447pwob8jUHWNw1yc+Q2V4ZHovBjnlNvKf6tzY1mnT+ERm5FTQg50irhgaSaTr/Pn35XBwElGGdNxHZ5TFiWs0hXP/wzbZR+bZ3+kzJujLh0aIDu59Oz3o
NbCTdZVdIMrQHDJe7wU1USI46Y2TDsPXMEJOBb2724r7d8qeF3qbkucCf4whypBOIMzLJ5zYrnir7M/D+sL/nHQQpizpGLQ8uMtDTCz7dQx01BwOWIh718cbn6b+3cESeK3yTd/TkzvXJCDHgj4atOR8vz89QgKL5+qNvdnUbGh1YIOxIfmM2dxhx+YOu/6JGjyfYb1PhSx4rj6z5/dqQ37nJWinkGIqqKAppoIKmmIqDDWQNrTawGt4LpwMGRvFECXgvWaynXg27gHVG2cH1MMofgPDacWClOT5GTXyaQxxmvfOwkrEMjE9DAn6li1lWb3xcYnBNz4s1z+xyKkPWExRD6NIyPz5rW52YHXzmYfyvFYR+65OvUd+9r0WKBQNqKAppoIKmmIqTL81MgMJPlWMtFGXVLdFhE3xkajQP2FzGo6R4FHMhknQnxX12cQBs3hKgxUXJxIx2jER6OONlbvMRs7z98QYSAYWpVoZCSWKeokSVLEgEsoG3FsGyoWvXqtIxJP08QyxKtzOSnBZ5GvwEhDIoJ9mekHXOEVsv7ItkRZEYOSKwRqfGMivZvXimuFyTMP/HnHhQUWQ1fSyOF5fJK8e6YkxmPLGIM1rblrapmk0c+/UgKbRTCrsXNaqaTSjx4XV2kYzLAPs+RzZYeu3MZDTry/q0DSa+droIjOaKXqk/jfqadLxh+YlhriGeuhSUuWRip1sXGQQFci3OpGnSF4zHQu1mIABNltGEMFo10tV72T011NxTVUhBNU14hla8Zpe0C0RDmNXGnsjq7lzmw93bvOd8fiObpvhPGa8qW2p9ZPdXvwkw02Bxr+WWb03tNo16yWC0a334jVVmsdfOOom7OMyhXYKKaaCCppiKqigKaaCCppiKnLaKbx6WBi/Pa/H0GfOWZl/99Ha18khtpjBZVozymJ4RWH6kg7Xb6jA7u7U19rVBziMXaE9NJhrWEg4fG2L/okK7tzmw5unnPonpklOBc0ygKM4t/cmiGY4VJSNehr9qZTAIJo9f5b0SKPeRk18jEKbHBRTQQVNMRVU0BRTUfQzhS/O70K5Isjm/l0e7OhKvfNUYRfwwjw5HkEEcPU6MlZgxcIOWBVf7Zs3lxFbEd812Y8F1fKqiddPOvCnHLt0fu+cIJYpPDHWt9rx6/3yhj7DXDyemtOTSMdE4PoNZL3eWNRBdK9v3FSGrljqjd5zy2N4aLocZ9ERZXFTlhcyXD0sTLjHHvRb8KOPfWlfr+gFXeuNE847pRZjEQlWFpih2KkrmdHMtLI4EdlmV7lwjnQLxDU+7sqttwQADHOReR4LkkJ0cGS9krmPTiuLE50wq8HfY49VJPJIx2hGj0o7mYeQ4Rq8ohf0D7f7CIHt6zFW5O4oi29+KEepJTOa+e7mMiKIp0XljPTU4RKsbJSHmhr6cj9089ejLmxoldfaNas2r28Ok/VKJoRvfVgGRlEvo0Y/u3usRB4RIftDFKub7WgMyXlkakZU9IJ+vy2zxadRkdFcswYA77VoH9/TY82544+ag34rDvrPnGcfz+rWa41OvfTojHJY3ZzbL29jyILGUPZkSDuFFFNBBU0xFVTQFFNR9G1oowx1CbhxTJ/+iRqsbHQWvfGL1yrimuHhjK4xxFnoufPsYzpBT/Dw+NlMv/6JGnzQZi96QVc5xIzraUaK+6lRKAahgqaYCipoiqko+jb0KDcPi+Jr1xRiCZOXfFDjEFCqMFzpiTHoiOZ2wqHKLhC+HYE4Y8hTIxu4OBFDXHLHkReB4wZ22C0ExV06AK8s7CRiOb66sTzj2UOjPDDNr2k0kwvunBTM2GgmU+ZUxjSNZoqRohd0d4wl3tDxAow0BeMMOqJyIYxacqWVJ0/mGchDnmpiIlmGrgJs+mSUohf0ZTpGJfng7h0+3L0jv3n+Yq8Hv9jr0T8xh2xqt2PGf4r7jaym+L9yFIoBqKAppoIKmmIqct6GTmfn1mLMoxjL8FmttxaMpFGiU8uXEwdr
X68xGOOQTm2NrorI1R01Uo5slaEQeWZSBiDXzzjZpkFDH3vsjBfI6A39l1t2YMrwAADg/n9NwJRhAbT5bXh2w4hPz8iHA1J+XZaSU4gyFEO9geIpRz9pt6Ff+v7HmDoiAI4FOBZ4+AuHcaLDCbdDwO1LjmWzjBRKyqQl6FeXb8eEwX1gGeD25yfj/YPlsHASfnjVUYSjHJ5+b2S2y0mhpERaTQ6nTQD76VchEmchiAx+unIcJg0NQgIQz3OsBYVymqyNcty+5DisnIg/rxuR+NvlQyJ4ZHpvIn2sz4Ivvl+RSNtYCZuWthHXmb2qmtip6a1LOlBll00nbt/mw5YOOZbju+OD+M44eYXKpnY7lm/3JdLVDgFvLupIpAUwuGBVNZHnpqVtsLFy5+bz71egQRGE89MZvVgyWI7l+OdxF35zQDZ9meqL49m5XYm0P87iUp0ZzrWXtaPUIs/j3/RhOfb1yqu8f1wbwBdHyrEcq5oceEARPzLKzeNfC2TH06jI4MJ3yHptuaKV+Am+Ym0lEVT1+/O7cUGlbKDzdF0Jnjkibw8xryqK3yncY1sjHJYpTHpYSNhyBfn85r1TTeyT8sqCTox084n0vTu9xGr1L40M4Ue1gUR6d48V3/oofTObrAna6+IH/M3BSRisiNYK8qQbCgMQx0//TdlvrnYIRHCSXfXyL7WSeZTbyeuxDJlHMqOZQU6BMJqxqPo5PptIXMOj2i7NypJlcMb0e/41DjKazsqSn/Gq8lRv0WZhyXolM5oZ5BQJoxm182e5ncyjxELmYVfVawBM8uenpMohkPdGZeLjtpB5NIUzC9ZJa9hu1V1bMLQ8intfmoC1+ypQVhLHPZ+rx8JJXfjd26Pwv+v739Jeq4hhLvlOR0XgSEB+CzGQUOslvwj7ei1Q3pYJnjghsBN9HIKKQJ1qh4Aqu/JLw+CE4u1qYSRM8KjzIP0uar1x4kHUBSzEjk/DXDy8ivDRzhiLFoXxi5MTMaZErqcgQdNTAwAmeeKEuc3RIIewoqk22CkQFmi9cQYnFf4VNlbC+FK5XhKA/ap6TfbGifQhv4XYTWukm0eJwomqLcKiXfEGL7GIGOmW68VLwCGiXhImq57f/l4L8Qs7vjRO7PF4MsQRex2W2wQMVrywwgKDo0G5nnkdtjve7kQoZkGoywJ/eOCleuMsenvP3J6WwAwQl5pDOsJoi3Bo04gT5iX9PNRCUHMyZMFJjeNhgcU+jXom44BOvZrD3AC3JCUxUb9eesdP6MQ2B3m9eumXoS6gfbwrxhny29Mjrd7bA69MQFP3wJjklzYPxru7Cx8dR/nsktYbettRHx5ZMR6NnbLf2z8/GoLekBWNXbnbboBC0SPtJsemw2RPdE9jYWN3KRTAoKB/VBvIeD8SCsUIdtZYrIghQf/XuJD+SRRKAdEU9LM7d+apGBRK6tyvcUxzHJphGMOxgePGuQAAkYiISZP6Z526u+PYvj1126rFiyuwbl0nLrqoHFZrfxPno496EAwmmT2g5AUXy2Kez4e4JGFDd7f+B3KIJEm5GYdWM2aME/Pn+wAw8PksWLasfyr22LEQYrEG7N4d0Pw8AMyb58O8eT5EoyJ+8INRcLn6xyiffLIBq1d3IBCgos43LpbFXJ8P83w+xEQRMVGEBGBzb6/uZ/NNVqOIpk4tBc9LKC+3Ys4cH44f729zjx7tws03D0/pGldfXY0332zH3XePwZEjfYh9uq3rLbeMwPDhdEiwEHgtFswoLcXDR4/ij42NuG/MGHyuqgozSkv1P5xnsiro115rg93O4sorq3DoUB82bOhGIMDjwIEgjh1Lzfr1nnsO45FHxsNuZ/H3vzcjHBZw5EgIBw4EEQ7Tt3MhaI7F8JsTJxLpmCjigfp63Dh4MCa4XAUs2UBytqZw+vRSDBvmwNq1nXjiiYa0rnHrrSPQ3c3jgQfq0N4e0/8AJadwAIY4HGiIRCBIEh5vaMCTkybhih15Ni3RIGeCXr++
C48/fkL/RA2WLz9A28xFAgdgrMuF24YPxx0HD2KY3Y6namshSRJqbDa0xorjhZOzSHy7nUVZmQVud/qBJz6fFWVllsRiAkrhqLHb8eDYsbjvyBHYWRbPTJ4MAGAYBi9MmVLg0snkTCpLllTi5Zdn4t57x8Dt5uB0Gs/q2Wen4uWXZ2L0aBfcbo4Ku0Cw6B/pqLBa8efaWrg5DkG+P2xUkqTEv4uBrEuE5yXwvBzfOnu2DytXzsKvfz0x5WtEoyKU4+NPPTUZK1fOwsSJud2OmJKcUU4nnpg0CQDgsVjwTG0tvrxnDyKiiKgo4gu7dxe4hDJZF/SLLzbhuedOQVBsbSpJEpHW47rrdqCvTyBELQgStCaBKLlDAiB8eu8lSQL/6X/X7tyJa3ftKmzhVGR9pvA0V11VhTvu6F/9vWtXAHfddcjwNV55ZQZKS/v7rd/73r6Uh/4o2WeI3Y5namvh53l8ec+egpZFa6YwZ4KmUHKFlqBpN4tiKqigKaaCCppiKop+Swqzsernc8GxDK57cAv6FGYav799KiYMJ4N9nn37BP657lQifcGkMjx8U//wWUNbGN/+TfFMORcL9A2dR9549AK4HRycdo7wAfnjndMxdbQXD//lED7/0FZ8/qGtWPNJO751xUjctKTf4+Ti6ZV46BsTUXcqiHue2YdxQ9147sczC1ORIoYKOo/c9KtPkjo3lTg4cByDUJRHINT/X4wX4bBxsNtYLD2vGnd9aRz2Hg/g/z6zH7vr/bjzD7sxejAVtRra5Mgjnf4Y+qcpyFGnHz29F7/8zmTc//UJiH66b53XbcU/1p7EP9aexKIZlSh1WRHjRfhD/dPMPcE4OJZBuceW51oUN1TQRcAd143FsConfvfqURxo6F/V8/XFw7HkvGq09UQR5zPze/ssQZscRcCIaiccNg5NnWHUN/WhvqkPvX1xVHhsqPDYsHFPJ559+wSmjfFg+efHYkS1Ew/cMBGd/hju/tO+Qhe/qKAzhXni1zf3h1vOnlgGhmGw/XA3BEHCT/92GMOqnLjnK+PRFxHg7+s3WBw5yIVPDvfgr2tOorE9jAqPDVfPHYQvLBiChrYwqnw2PPrXw9hZX3zr+nINnfouAi6eXpn075v3dyESFzFrvBceF2lseLS5Dw1tcvxKhceGqaP7HapCUR5bD/bkrLzFDBU0xVTQWA7KZwYqaIqpoIKmmAoqaIqpoIKmmAoqaIqpoIKmmAoqaIqpoIKmmAoqaIqpoIKmmAoqaIqpoIKmmAoqaIqpoIKmmArNeGgK5WyDvqEppoIKmmIqqKAppoIKmmIqqKAppoIKmmIq/j+hLnQPRC5vJAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_environment(env)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `step()` function actually returns several important objects:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs, reward, done, info = env.step(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The observation tells the agent what the environment looks like, as discussed earlier. This is a 210x160 RGB image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(210, 160, 3)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The environment also tells the agent how much reward it got during the last step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the game is over, the environment returns `done=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "done"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, `info` is an environment-specific dictionary that can provide some extra information about the internal state of the environment. This is useful for debugging, but your agent should not use this information for learning (it would be cheating)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ale.lives': 3}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's play one full game (with 3 lives), picking a new random action every 10 steps and recording each frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = []\n",
    "\n",
    "n_max_steps = 1000\n",
    "n_change_steps = 10\n",
    "\n",
    "env.seed(42)\n",
    "obs = env.reset()\n",
    "for step in range(n_max_steps):\n",
    "    img = env.render(mode=\"rgb_array\")\n",
    "    frames.append(img)\n",
    "    if step % n_change_steps == 0:\n",
    "        action = env.action_space.sample() # play randomly\n",
    "    obs, reward, done, info = env.step(action)\n",
    "    if done:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now show the animation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_scene(num, frames, patch):\n",
    "    patch.set_data(frames[num])\n",
    "    return patch,\n",
    "\n",
    "def plot_animation(frames, repeat=False, interval=40):\n",
    "    fig = plt.figure()\n",
    "    patch = plt.imshow(frames[0])\n",
    "    plt.axis('off')\n",
    "    anim = animation.FuncAnimation(\n",
    "        fig, update_scene, fargs=(frames, patch),\n",
    "        frames=len(frames), repeat=repeat, interval=interval)\n",
    "    plt.close()\n",
    "    return anim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you have finished playing with an environment, you should close it to free up resources:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To code our first learning agent, we will be using a simpler environment: the Cart-Pole. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simple environment: the Cart-Pole"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Cart-Pole is a very simple environment composed of a cart that can move left or right, and a pole placed vertically on top of it. The agent must move the cart left or right to keep the pole upright."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"CartPole-v0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.seed(42)\n",
    "obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.01258566, -0.00156614,  0.04207708, -0.00180545])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The observation is a 1D NumPy array composed of 4 floats: they represent the cart's horizontal position, its velocity, the angle of the pole (0 = vertical), and the angular velocity. Let's render the environment..."
   ]
  },
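  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, we can unpack and label each component (the variable names below are just for readability, they are not part of the Gym API):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "position, velocity, angle, angular_velocity = obs\n",
    "print(\"position:\", position)\n",
    "print(\"velocity:\", velocity)\n",
    "print(\"angle:\", angle)\n",
    "print(\"angular velocity:\", angular_velocity)"
   ]
  },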
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAADICAYAAACuyvefAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAF4klEQVR4nO3dzY8bdx3H8e/Y3sfulhSiKlSoIr2UWy9BKDlHAokD/wF3cuN/yYkLh4gr/wRCWiEFISHoAVVERXS3PKQJWbRee8fDgQjqbHcDnk39sfN63Wb8oK+0o7e8Y89vmq7rCiDFYNkDAHyeKAFRRAmIIkpAFFECoogSEGX0ksf9XgB4FZqLHvBJCYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBElIIooAVFECYgiSkAUUQKiiBIQRZSAKKIERBktewBWX9d1Vd3sc3uaqqappmmWNhOrS5TobfzkqD7+xc+qex6m4eZ23fjgu7X/zvtLnoxVJEr0NpuO6/jTj+Y+Le2/8y1RYiHOKdFbOx2f2zfc2F7CJKwDUaK38WdH8+eUmqa23/r68gZipYkSV6B7YdsJbhYnSkAUUQKiiBK9dF33798pwRURJXo7+fuf5raHG1u1+ca15QzDyhMlepudTee2m+GoBn4SwIJECYgiSkAUUQKiiBL9dF21k5O5Xc1gaIUAFiZK9DJrp3X67K9z+7b2r9dwc2dJE7HqRImr1wzKpSYsSpSAKKIERBElIIoo0cvZ+FmdnRzP7du+dqPKt28sSJToZXY2rVk7f5nJaGvXTwJYmCgBUUQJiCJKQBRRopeundaLa3QPRpvLGYa1IEr0Mn5yVF17Nrdv56vfWNI0rANRopcvXArXN2/0IEpAFFECoogSPbmTCVdLlOjl5PGf57ab4UZt7H5lSdOwDkSJXtrJeG67GQxrtL23pGlYB6IERBElIIooAVFEiYV1XVez6encvqYZWLaEXkSJhXWztsZPDuf2be695UQ3vYgSV6sZPL+bCSzG0QNEESUgiigBUUSJhbWTk5qePJvbt/Xm9WoGDisW5+hhYV07rdl0/jKT0fZeNU5004OjB4giSkAUUQKiiBIL62btuTW6m8FoSdOwLkSJhY2f/uXctW+7199d0jSsC1Ficd2sXlwO1zdv9OUIAqKIEhBFlIAozRfe4fS/3D/nNdS2bT148KCOjo4ufd77107r3cEn/9k+a2f1y0/3a7L99qWvu3v3bt26detKZmVlXbgSoO9vOadt27p//349fPjw0uf96Affrh9+7zv12eTt2hyc1nYd1f2f/LT+ePjk0tft7OyIEhcSJRZ22u7WwePv19Pp12rYtPXNrYPqup8veyxWnHNKLOyT8Xv1dHq9qgbVdhv10T8/qMlsZ9ljseJEiYV1L5wWmHXDarvhkqZhXYgSC7v55se1P3pcVV0N6qzee+M3tT08XvZYrLhLzynNZrMvaw6C/K9/91/99nf14aMf1+PJjdocjGu3Duvwb/946eu6rnNsveYGlywEeGmUDg4OrnwY8k2n0zo+fvknnl//4bCqDqvq9//X+z969Mix9Zq7c+fOhY9dGqXLXsj6mkwmtbf36u7ddvPmTccWF3JOCYgiSkAUUQKiiBIQRZSAKK5945zhcFj37t2rw8PDV/L+t2/ffiXvy3qwdAmwDBcuXeLfNyCKKAFRRAmIIkpA
FFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEEWUgCiiBEQRJSCKKAFRRAmIIkpAFFECoogSEGX0ksebL2UKgOd8UgKiiBIQRZSAKKIERBElIIooAVH+Bc6iAHz8aqUYAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_environment(env)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's look at the action space:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Discrete(2)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yep, just two possible actions: accelerate towards the left or towards the right. Let's push the cart left until the pole falls:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.seed(42)\n",
    "obs = env.reset()\n",
    "while True:\n",
    "    obs, reward, done, info = env.step(0)\n",
    "    if done:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving figure cart_pole_plot\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAADvCAYAAADM8A71AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAJwklEQVR4nO3dTW9b95XA4UNSfJFIybJlWbabxOk0BYJOi+lmMNtBELSfYr5FPsF8j1kNsu6iiwDZJUAX00wHaNGZtqk7sRu7lh3rzRJFii+6XRRt4VzalmxZh1d5nuWhbJ0F8YNwyf+9taIoAoDzV89eAOCbSoABkggwQBIBBkgiwABJFl7wuq9IALy62qyhv4ABkggwQBIBBkgiwABJBBggiQADJBFggCQCDJBEgAGSCDBAEgEGSCLAAEkEGCCJAAMkEWCAJAIMkESAAZIIMEASAQZIIsAASQQYIIkAAyQRYIAkAgyQRIABkggwQBIBBkgiwABJBBggiQADJBFggCQCDJBEgAGSCDBAEgEGSCLAAEkEGCCJAAMkEWCAJAIMkESAAZIIMEASAQZIIsAASQQYIIkAAyQRYIAkAgyQRIABkggwQBIBBkgiwABJBBggiQADJBFggCQCDJBEgAGSCDBAEgEGSCLAAEkEGCCJAAMkEWCAJAIMkESAAZIIMEASAQZIIsAASQQYIIkAAyQRYIAkAgyQRIABkggwQBIBBkgiwABJBBggiQADJBFggCQCDJBEgAGSCDBAEgEGSCLAAEkEGCCJAAMkEWCAJAIMkESAAZIIMEASAQZIIsAASQQYIIkAAyQRYIAkAgyQRIABkixkLwBnoSiKOPzqThxu3YvutbejvXw16gutqNUb2avBMwkwF8bmLz+O3S/+J2qNZrR6V6K9sh7d9Vux8U8/jkaznb0elLgEwYVwPD6K4e6DiIgopuM42nsYT778dWzf/nlEcZy8HcwmwFwIR/uPY7S/XZovXvlW1JudhI3gxQSYC2G0/ziOJ6PSfGn97fNfBk5IgKm8oiii/+iLiCiemtfqjeht/EPUarWcxeAFBJjqK46j/9Xd0rjebEert5awEJyMAFN548F+DHcelOadSxvR7K6e/0JwQgJM5Y0P92Jy1C/Nl9bf9j1g5poAU2l/vf5bTMel15auvuX6L3NNgKm8/sM/lGb1Zju667cStoGTE2Aq7XgyisHOn0rzhXY3mkur578QnIIAU2mjg+0Y7W+V5p3LN6PRXkrYCE5OgKm00f5WTMfD0tzlB6pAgKmsvx3AKJ4+gBG1enQdwKACBJjqKoo4fPzH0ri+0Ir28nrCQnA6AkxlTY76Mdi+V5q3V9aj1bucsBGcjgBTWZPBk5gMD0rz7rW3o9ZoJmwEpyPAVFb/0Z3Zd0Bbe9P1XypBgKmsg83bpVl9oeUWlFSGAFNJzzqA0WgtRqt3JWEjOD0BppJG/d04evJVad65fCMW2t2EjeD0BJhKGh1sx3Q0KM2X1t6McP2XihBgKufvBzC+9rDNWj1619/xARyVIcBU0DMOYDSa0V65lrAPvBwBpnKmo0EMtsoHMFrLa9Fa9ggiqkOAqZzJYD/Gg73SvLt+K+oLrYSN4OUIMJVz+PiPcTw+Ks0X195w/ZdKEWAqZ//B70uzWmMhute+nbANvDwBplKOp+OZT0ButBajvXw1YSN4eQJMpYwPn8Rwb7M071y6HgudXsJG8PIEmEoZ93dienRYmi+uvRFR83amWrxjqZT+oy+iOJ4+PazVonf9Oz6Ao3IEmMooiiIOH39ZmtfqC9FZvZ6wEbwaAaYyjsfDmSfgWr0rHkFEJQkwlTEZHsT4cLc0X7r6VtSb7fNfCF6RAFMZh4+/fMYd0BzAoJoEmMo42JxxAKPecACDyhJgKuF4OonBjAMY9WbHHdCoLAGmEibDgxjOeARR59JGLCwuJ2wEr06AqYRxf2fmI+gXr9yMWr2RsBG8OgFm7hVFEQcP
/798ACNqnoBBpQkwlfCsD+A6qzcStoGzIcDMvePJaOYTkJvd1Whf8gEc1SXAzL3pUT9GBzul+dLVN6PR6iRsBGdDgJl7h1v3Zt8B7fLNiHD9l+oSYObewebtiCieHtbq0d1wBzSqTYCZa8XxdOb3fxvNdnRWNxI2grMjwMy1yVE/Btv3S/P2yno0Fy8lbARnR4CZa+P+XowHT0rzzuUbUWs4gEG1CTBzqyiKvzwBYzopvdbbeCdqHkFExXkHM9eedQBj8crNhG3gbAkwc6uYjmO496g0X1hc8QgiLgQBZm5NR4MY7W+V5ktrb0SjtZSwEZwtAWZuDbbvx+SofAe0zuUbEb7/ywUgwMytg83bEUX5AEbPAQwuCAFmLhXHx7OfgLHQjM5lH8BxMQgwc2k6GsRg68vSvL18NZpLDmBwMQgwc2l8uBej/m5p3r60EfVG8/wXgtdAgJk7RVHE4eO7UUzHpdd619+JWt3blovBO5m5tP/g8/KwVo+ltTfOfxl4TQSYuVMcT+Jo1gGMTs8BDC6UhewFuPiGw2EcHJS/z/sstekwhjMeQdTorcfuwTBqh+VLE0/9XKMRq6urvqrG3KsVX/+e5dOe+yKcxIcffhgffPDBiX/+h9/ZiH//t3+J+tf6+ZOf/T7+4+PfvPDfv/vuu/HRRx9Fp+NxRcyNmX8N+AuY124wGMTm5uaJf/76D2/E1uhmPDq6FZ16PzY6d2OxvhP/9b93T/T/rK2txQv+sIC5IMDMlVqtFq3Vf45f7PwoJsVfvm529/B78YPuT+P2/e3k7eBsCTBzpdNqR+fqv8akaP1tdjhdjv++dz0e7pz8OjJUgW9BMFcuLa/EldW1r01r8asvdmI4Kt+YHapMgJkr37+1HBvd3Xj689/j2N66XbovD1SdADNXnvSH8eDz/4zdR7+I4fAgWrEX3176Zew9+ix7NThzz70GPB4///uWcBLT6fTEP/vz396Pz353PxqNj6PXuxrfvdmLf3yrG7fvPTzV75xMJt6/zI1mc/b9S54b4E8++eS1LMM3y+efzzhW/BxFETGZjGJ390/x2W7EZ/93ut/X7/fj008/jXa7fbp/CK/J+++/P3P+3AA/6x/Bady5c+dcf1+324333nsvFhcXz/X3wmm5BgyQRIABkggwQBIBBkgiwABJBBggiZvx8NotLS3FzZvn9yj5a9euuRk7leCG7Lx2R0dH0e/3z+33NRqNWFlZEWHmycw3owADvH4zA+waMEASAQZIIsAASQQYIIkAAyQRYIAkAgyQRIABkggwQBIBBkgiwABJBBggiQADJBFggCQCDJBEgAGSCDBAEgEGSCLAAEkEGCCJAAMkEWCAJAIMkESAAZIIMEASAQZIIsAASQQYIIkAAyQRYIAkAgyQRIABkggwQBIBBkgiwABJBBggiQADJBFggCQCDJBEgAGSCDBAEgEGSCLAAEkEGCCJAAMkEWCAJAsveL12LlsAfAP5CxggiQADJBFggCQCDJBEgAGSCDBAkj8Db3YivUWpwvQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_environment(env)\n",
    "save_fig(\"cart_pole_plot\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(210, 160, 3)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the game is over when the pole tilts too much, not when it actually falls. Now let's reset the environment and push the cart to the right instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs = env.reset()\n",
    "while True:\n",
    "    obs, reward, done, info = env.step(1)\n",
    "    if done:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAADICAYAAACuyvefAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAH9ElEQVR4nO3dzW5U9xnA4feMZ/wxNhiwwQnNh/JRFt20WaSq0m2lrnsFVTZdddsLyDX0CrKs1H0uobRSm0ipqogEkoZAANsEbOzx2J45XSRNOICN54Ocd4bn2fk/9uhdWD/OHB1eF2VZBkAWjboHAHiYKAGpiBKQiigBqYgSkIooAak0n/K65wWAZ6E46gVXSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqosRQyn4vegfdKMt+3aMwZZp1D8DkOOhsx+76F7H99afRuXsjDnbvx+u/+UPML1+oezSmiChxYnc/+0d8dfmvEd9fHRWxu/5fUWKsfHzjxNorLz9yUsbu5vUoy7KWeZhOosSJzZ9Zi+bcYuWsc/erh66cYHSixInNzC3G3OnVylln80YcdndrmohpJEqcWNGYiYVHPsId7j2I/Qd3a5qIaSRKnFhRFLH0wpuVs7J/GJ27N2qaiGkkSgxk4eyLUcy0Kmc7t6+62c3YiBIDaS2ejeb8UuVsd+N6lP3DmiZi2ogSA2nOLUZ75aXKWXd7Iw73dmqaiGkjSgykaDRi4ZEo9fY70b1/p6aJmDaixMDaK69ERPHDQdn3ECVjI0oMbGHlpZiZna+cPbj1aU3TMG1EiYG12svRWjxTOdu7dyv6h/v1DMRUESUG1mjOxsLZi5Wzg517cdjZqmkipokoMZTFC69Xvu7t78Xuxpc1TcM0ESUGVhRFtM+/GlE8/OtTxu7mV252MzJRYihzp89Hc65dOetsXo8IUWI0osRQmvOLMXvqkY0Bd29Ez8YARiRKDKVoNGPxwmuVs4POlo0BjEyUGEpRFNFefaVyVvZsDGB0osTQvt0YUF3zvnPncze7GYkoMbS55QvRap+pnO2sfxFlv1fPQEwFUWJoM635x/6Syf72ZhzuPahpIqaBKDG0b9fjPmFjwNZ6TRMxDUSJkSy98GZE8cPGgLLfi53bV2uciEknSoxk/syL0WjOVc52N6wxYXiixEha7eVotZcrZ3v3vrYxgKGJEiNpNGdj4Vx1Y0B3e8NDlAxNlBhZe/XVytf9g/3o3r9d0zRMOlFiJEVRxNLaG0/YGOC+EsMRJUY2d3r1sY0BD25djSj7NU3EJBMlRtacX4rZpXOVs+7929E72KtpIiaZKDGyYqZlYwBjI0qMrCiKWFh5uXJW9g5jd+N6TRMxyUSJsVg4dzGKxkzlbOf2VTe7GZgoMRbzy2uPPUTZuXvDxgAGJkqMxczsQsydrm4M6G5vRq+7U9NETCpRYjyKRrRXq/eVet2d6G5v1DQQk0qUGIuiKL77W3CPbgy4Vt9QTCRRYmwWVn4SjeZs5czGAAYlSoxNq70crcUzlbPONzej7B3UMxATSZQYm0ZzNhbOPrIxYGs99nfu1TMQE0mUGKMill54o3LSP+jG3r1bNc3DJBIlxub7J7sf3Riw8aX7SpyYKDFWc6fPx8zsQuVs587nEaLECYkS
Y9VqL8f88lrlbO+bmzYGcGKixFgVjZnH1uMedLZtDODERImxevLGgINv/x+cj3CcgCgxdosXXotiplU5e3Drs5qmYdKIEmM3d2olmvOLlTMbAzip4imX1K63ecz6+nq8//77cXh4+MTXiyjj5+0bcX7+oSe5W4txeetibHee/DPff1urFe+++26cO3fu2O9j4hVHviBKDOrjjz+Ot99+O7rd7pHf88ff/TJ+/9tfxNbhSnR6S9FurMef/vyX+OeVr49973a7HR9++GFcunRp3GOTy5FRav6YU/D8+Pfnd+J656fxn61fx2HZioWZ7biw9q+Ip0QJ3FPimbh6cys++eZSHJazEVFEp3cqllZ/VfdYTABR4pnYuL8T6/c6lbOfvXoh5mddnHM8UeKZ6O7vR7n1t5gpDqIsy2j07sfFhc+iPdd6+g/zXDv2n61+31845XEneQiyX5Zx5ZMP4qUHX8bfr+zExvqnsX3/ZmztHn1z/OH397s33RqNo6+Hjo3S5cuXxz4Mk+/atWsnisYHl69EGVcGeu9+vx8fffRRbG5uDjseE+Cdd9458rVjo3TcD/L8OnXq1LH/0v3fMM+TNBqNeOuttzwS8BxzTwlIRZSAVEQJSEWUgFRECUjF47UMbG1tLd57770jtwSMotVqxerq6tjfl8lhSwBQhyO3BPj4BqQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqogSkIkpAKqIEpCJKQCqiBKQiSkAqzae8XvwoUwB8x5USkIooAamIEpCKKAGpiBKQiigBqfwPAwzfnkuwyTYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_environment(env)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks like it's doing what we're telling it to do. Now how can we make the pole remain upright? We will need to define a _policy_ for that. This is the strategy that the agent will use to select an action at each step. It can use all the past actions and observations to decide what to do."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simple hard-coded policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's hard-code a simple strategy: if the pole is tilting to the left, then push the cart to the left, and _vice versa_. Let's see if that works:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = []\n",
    "\n",
    "n_max_steps = 1000\n",
    "n_change_steps = 10\n",
    "\n",
    "env.seed(42)\n",
    "obs = env.reset()\n",
    "for step in range(n_max_steps):\n",
    "    img = env.render(mode=\"rgb_array\")\n",
    "    frames.append(img)\n",
    "\n",
    "    # hard-coded policy\n",
    "    position, velocity, angle, angular_velocity = obs\n",
    "    if angle < 0:\n",
    "        action = 0\n",
    "    else:\n",
    "        action = 1\n",
    "\n",
    "    obs, reward, done, info = env.step(action)\n",
    "    if done:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nope, the system is unstable and after just a few wobbles, the pole ends up too tilted: game over. We will need to be smarter than that!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Network Policies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a neural network that will take observations as inputs, and output the action to take for each observation. To choose an action, the network will first estimate a probability for each action, then select an action randomly according to the estimated probabilities. In the case of the Cart-Pole environment, there are just two possible actions (left or right), so we only need one output neuron: it will output the probability `p` of the action 0 (left), and of course the probability of action 1 (right) will be `1 - p`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: instead of using the `fully_connected()` function from the `tensorflow.contrib.layers` module (as in the book), we now use the `dense()` function from the `tf.layers` module, which did not exist when this chapter was written. This is preferable because anything in contrib may change or be deleted without notice, while `tf.layers` is part of the official API. As you will see, the code is mostly the same.\n",
    "\n",
    "The main differences relevant to this chapter are:\n",
    "* the `_fn` suffix was removed in all the parameters that had it (for example the `activation_fn` parameter was renamed to `activation`).\n",
    "* the `weights` parameter was renamed to `kernel`.\n",
    "* the default activation is `None` instead of `tf.nn.relu`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-36-e360db0650cb>:12: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /Users/ageron/miniconda3/envs/tf1/lib/python3.7/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From <ipython-input-36-e360db0650cb>:18: multinomial (from tensorflow.python.ops.random_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.random.categorical` instead.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# 1. Specify the network architecture\n",
    "n_inputs = 4  # == env.observation_space.shape[0]\n",
    "n_hidden = 4  # it's a simple task, we don't need more than this\n",
    "n_outputs = 1 # only outputs the probability of accelerating left\n",
    "initializer = tf.variance_scaling_initializer()\n",
    "\n",
    "# 2. Build the neural network\n",
    "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n",
    "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu,\n",
    "                         kernel_initializer=initializer)\n",
    "outputs = tf.layers.dense(hidden, n_outputs, activation=tf.nn.sigmoid,\n",
    "                          kernel_initializer=initializer)\n",
    "\n",
    "# 3. Select a random action based on the estimated probabilities\n",
    "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n",
    "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n",
    "\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this particular environment, the past actions and observations can safely be ignored, since each observation contains the environment's full state. If there were some hidden state then you may need to consider past actions and observations in order to try to infer the hidden state of the environment. For example, if the environment only revealed the position of the cart but not its velocity, you would have to consider not only the current observation but also the previous observation in order to estimate the current velocity. Another example is if the observations are noisy: you may want to use the past few observations to estimate the most likely current state. Our problem is thus as simple as can be: the current observation is noise-free and contains the environment's full state."
   ]
  },
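  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, if the velocity were hidden, a simple workaround would be to stack the current observation with the previous one and feed both to the policy network. This is just an illustrative sketch, not something the Cart-Pole environment requires:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: concatenate the previous and current observations so\n",
    "# the network can infer velocities from consecutive positions. The input\n",
    "# layer would then need shape [None, 2 * n_inputs].\n",
    "prev_obs = np.zeros(n_inputs)  # no previous observation at the start of an episode\n",
    "stacked_obs = np.concatenate([prev_obs, obs])\n",
    "stacked_obs.shape"
   ]
  },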
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may wonder why we are picking a random action based on the probability given by the policy network, rather than just picking the action with the highest probability. This approach lets the agent find the right balance between _exploring_ new actions and _exploiting_ the actions that are known to work well. Here's an analogy: suppose you go to a restaurant for the first time, and all the dishes look equally appealing so you randomly pick one. If it turns out to be good, you can increase the probability to order it next time, but you shouldn't increase that probability to 100%, or else you will never try out the other dishes, some of which may be even better than the one you tried."
   ]
  },
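  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampling step can be sketched in plain NumPy (an illustrative sketch, not code from the book):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_action(p_left):\n",
    "    # action 0 (left) with probability p_left, action 1 (right) otherwise\n",
    "    return np.random.choice([0, 1], p=[p_left, 1 - p_left])\n",
    "\n",
    "np.random.seed(42)\n",
    "actions = [sample_action(0.7) for _ in range(1000)]\n",
    "# roughly 70% of the sampled actions are 0 (left),\n",
    "# but action 1 (right) still gets explored\n",
    "```"
   ]
  },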
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's randomly initialize this policy neural network and use it to play one game:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_max_steps = 1000\n",
    "frames = []\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    init.run()\n",
    "    env.seed(42)\n",
    "    obs = env.reset()\n",
    "    for step in range(n_max_steps):\n",
    "        img = env.render(mode=\"rgb_array\")\n",
    "        frames.append(img)\n",
    "        action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n",
    "        obs, reward, done, info = env.step(action_val[0][0])\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's look at how well this randomly initialized policy network performed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yeah... pretty bad. The neural network will have to learn to do better. First let's see whether it is capable of learning the basic policy we used earlier: go left if the pole is tilting left, and go right if it is tilting right. The following code defines the same neural network, but this time we add the target probabilities `y` and the training operations (`cross_entropy`, `optimizer` and `training_op`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/ageron/miniconda3/envs/tf1/lib/python3.7/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "reset_graph()\n",
    "\n",
    "n_inputs = 4\n",
    "n_hidden = 4\n",
    "n_outputs = 1\n",
    "\n",
    "learning_rate = 0.01\n",
    "\n",
    "initializer = tf.variance_scaling_initializer()\n",
    "\n",
    "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n",
    "y = tf.placeholder(tf.float32, shape=[None, n_outputs])\n",
    "\n",
    "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n",
    "logits = tf.layers.dense(hidden, n_outputs)\n",
    "outputs = tf.nn.sigmoid(logits) # probability of action 0 (left)\n",
    "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n",
    "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n",
    "\n",
    "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
    "training_op = optimizer.minimize(cross_entropy)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can make the same net play in 10 different environments in parallel, and train it for 1,000 iterations. We also reset each environment when it is done."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_environments = 10\n",
    "n_iterations = 1000\n",
    "\n",
    "envs = [gym.make(\"CartPole-v0\") for _ in range(n_environments)]\n",
    "observations = [env.reset() for env in envs]\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    init.run()\n",
    "    for iteration in range(n_iterations):\n",
    "        target_probas = np.array([([1.] if obs[2] < 0 else [0.]) for obs in observations]) # if angle<0 we want proba(left)=1., or else proba(left)=0.\n",
    "        action_val, _ = sess.run([action, training_op], feed_dict={X: np.array(observations), y: target_probas})\n",
    "        for env_index, env in enumerate(envs):\n",
    "            obs, reward, done, info = env.step(action_val[env_index][0])\n",
    "            observations[env_index] = obs if not done else env.reset()\n",
    "    saver.save(sess, \"./my_policy_net_basic.ckpt\")\n",
    "\n",
    "for env in envs:\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def render_policy_net(model_path, action, X, n_max_steps=1000):\n",
    "    frames = []\n",
    "    env = gym.make(\"CartPole-v0\")\n",
    "    obs = env.reset()\n",
    "    with tf.Session() as sess:\n",
    "        saver.restore(sess, model_path)\n",
    "        for step in range(n_max_steps):\n",
    "            img = env.render(mode=\"rgb_array\")\n",
    "            frames.append(img)\n",
    "            action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n",
    "            obs, reward, done, info = env.step(action_val[0][0])\n",
    "            if done:\n",
    "                break\n",
    "    env.close()\n",
    "    return frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = render_policy_net(\"./my_policy_net_basic.ckpt\", action, X)\n",
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks like it learned the policy correctly. Now let's see if it can learn a better policy on its own."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy Gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train this neural network we will need to define the target probabilities `y`. If an action is good we should increase its probability, and conversely if it is bad we should reduce it. But how do we know whether an action is good or bad? The problem is that most actions have delayed effects, so when you win or lose points in a game, it is not clear which actions contributed to this result: was it just the last action? Or the last 10? Or just one action 50 steps earlier? This is called the _credit assignment problem_.\n",
    "\n",
    "The _Policy Gradients_ algorithm tackles this problem by first playing multiple games, then making the actions in good games slightly more likely, and the actions in bad games slightly less likely. First we play, then we go back and think about what we did."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-46-d42d007ce403>:21: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "reset_graph()\n",
    "\n",
    "n_inputs = 4\n",
    "n_hidden = 4\n",
    "n_outputs = 1\n",
    "\n",
    "learning_rate = 0.01\n",
    "\n",
    "initializer = tf.variance_scaling_initializer()\n",
    "\n",
    "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n",
    "\n",
    "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n",
    "logits = tf.layers.dense(hidden, n_outputs)\n",
    "outputs = tf.nn.sigmoid(logits)  # probability of action 0 (left)\n",
    "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n",
    "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n",
    "\n",
    "y = 1. - tf.to_float(action)\n",
    "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
    "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n",
    "gradients = [grad for grad, variable in grads_and_vars]\n",
    "gradient_placeholders = []\n",
    "grads_and_vars_feed = []\n",
    "for grad, variable in grads_and_vars:\n",
    "    gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n",
    "    gradient_placeholders.append(gradient_placeholder)\n",
    "    grads_and_vars_feed.append((gradient_placeholder, variable))\n",
    "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discount_rewards(rewards, discount_rate):\n",
    "    discounted_rewards = np.zeros(len(rewards))\n",
    "    cumulative_rewards = 0\n",
    "    for step in reversed(range(len(rewards))):\n",
    "        cumulative_rewards = rewards[step] + cumulative_rewards * discount_rate\n",
    "        discounted_rewards[step] = cumulative_rewards\n",
    "    return discounted_rewards\n",
    "\n",
    "def discount_and_normalize_rewards(all_rewards, discount_rate):\n",
    "    all_discounted_rewards = [discount_rewards(rewards, discount_rate) for rewards in all_rewards]\n",
    "    flat_rewards = np.concatenate(all_discounted_rewards)\n",
    "    reward_mean = flat_rewards.mean()\n",
    "    reward_std = flat_rewards.std()\n",
    "    return [(discounted_rewards - reward_mean)/reward_std for discounted_rewards in all_discounted_rewards]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-22., -40., -50.])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "discount_rewards([10, 0, -50], discount_rate=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([-0.28435071, -0.86597718, -1.18910299]),\n",
       " array([1.26665318, 1.0727777 ])]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "discount_and_normalize_rewards([[10, 0, -50], [10, 20]], discount_rate=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 249"
     ]
    }
   ],
   "source": [
    "env = gym.make(\"CartPole-v0\")\n",
    "\n",
    "n_games_per_update = 10\n",
    "n_max_steps = 1000\n",
    "n_iterations = 250\n",
    "save_iterations = 10\n",
    "discount_rate = 0.95\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    init.run()\n",
    "    for iteration in range(n_iterations):\n",
    "        print(\"\\rIteration: {}\".format(iteration), end=\"\")\n",
    "        all_rewards = []\n",
    "        all_gradients = []\n",
    "        for game in range(n_games_per_update):\n",
    "            current_rewards = []\n",
    "            current_gradients = []\n",
    "            obs = env.reset()\n",
    "            for step in range(n_max_steps):\n",
    "                action_val, gradients_val = sess.run([action, gradients], feed_dict={X: obs.reshape(1, n_inputs)})\n",
    "                obs, reward, done, info = env.step(action_val[0][0])\n",
    "                current_rewards.append(reward)\n",
    "                current_gradients.append(gradients_val)\n",
    "                if done:\n",
    "                    break\n",
    "            all_rewards.append(current_rewards)\n",
    "            all_gradients.append(current_gradients)\n",
    "\n",
    "        all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n",
    "        feed_dict = {}\n",
    "        for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n",
    "            mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n",
    "                                      for game_index, rewards in enumerate(all_rewards)\n",
    "                                          for step, reward in enumerate(rewards)], axis=0)\n",
    "            feed_dict[gradient_placeholder] = mean_gradients\n",
    "        sess.run(training_op, feed_dict=feed_dict)\n",
    "        if iteration % save_iterations == 0:\n",
    "            saver.save(sess, \"./my_policy_net_pg.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = render_policy_net(\"./my_policy_net_pg.ckpt\", action, X, n_max_steps=1000)\n",
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Markov Chains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "States: 0 0 3 \n",
      "States: 0 1 2 1 2 1 2 1 2 1 3 \n",
      "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n",
      "States: 0 3 \n",
      "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n",
      "States: 0 1 3 \n",
      "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 ...\n",
      "States: 0 0 3 \n",
      "States: 0 0 0 1 2 1 2 1 3 \n",
      "States: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n"
     ]
    }
   ],
   "source": [
    "transition_probabilities = [\n",
    "        [0.7, 0.2, 0.0, 0.1],  # from s0 to s0, s1, s2, s3\n",
    "        [0.0, 0.0, 0.9, 0.1],  # from s1 to ...\n",
    "        [0.0, 1.0, 0.0, 0.0],  # from s2 to ...\n",
    "        [0.0, 0.0, 0.0, 1.0],  # from s3 to ...\n",
    "    ]\n",
    "\n",
    "n_max_steps = 50\n",
    "\n",
    "def print_sequence(start_state=0):\n",
    "    current_state = start_state\n",
    "    print(\"States:\", end=\" \")\n",
    "    for step in range(n_max_steps):\n",
    "        print(current_state, end=\" \")\n",
    "        if current_state == 3:\n",
    "            break\n",
    "        current_state = np.random.choice(range(4), p=transition_probabilities[current_state])\n",
    "    else:\n",
    "        print(\"...\", end=\"\")\n",
    "    print()\n",
    "\n",
    "for _ in range(10):\n",
    "    print_sequence()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Markov Decision Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "policy_fire\n",
      "States (+rewards): 0 (10) 0 (10) 0 1 (-50) 2 2 2 (40) 0 (10) 0 (10) 0 (10) ... Total rewards = 210\n",
      "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 2 (40) 0 (10) ... Total rewards = 70\n",
      "States (+rewards): 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... Total rewards = 70\n",
      "States (+rewards): 0 1 (-50) 2 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 ... Total rewards = -10\n",
      "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) ... Total rewards = 290\n",
      "Summary: mean=121.1, std=129.333766, min=-330, max=470\n",
      "\n",
      "policy_random\n",
      "States (+rewards): 0 1 (-50) 2 1 (-50) 2 (40) 0 1 (-50) 2 2 (40) 0 ... Total rewards = -60\n",
      "States (+rewards): 0 (10) 0 0 0 0 0 (10) 0 0 0 (10) 0 ... Total rewards = -30\n",
      "States (+rewards): 0 1 1 (-50) 2 (40) 0 0 1 1 1 1 ... Total rewards = 10\n",
      "States (+rewards): 0 (10) 0 (10) 0 0 0 0 1 (-50) 2 (40) 0 0 ... Total rewards = 0\n",
      "States (+rewards): 0 0 (10) 0 1 (-50) 2 (40) 0 0 0 0 (10) 0 (10) ... Total rewards = 40\n",
      "Summary: mean=-22.1, std=88.152740, min=-380, max=200\n",
      "\n",
      "policy_safe\n",
      "States (+rewards): 0 1 1 1 1 1 1 1 1 1 ... Total rewards = 0\n",
      "States (+rewards): 0 1 1 1 1 1 1 1 1 1 ... Total rewards = 0\n",
      "States (+rewards): 0 (10) 0 (10) 0 (10) 0 1 1 1 1 1 1 ... Total rewards = 30\n",
      "States (+rewards): 0 (10) 0 1 1 1 1 1 1 1 1 ... Total rewards = 10\n",
      "States (+rewards): 0 1 1 1 1 1 1 1 1 1 ... Total rewards = 0\n",
      "Summary: mean=22.3, std=26.244312, min=0, max=170\n",
      "\n"
     ]
    }
   ],
   "source": [
    "transition_probabilities = [\n",
    "        [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]], # in s0, if action a0 then proba 0.7 to state s0 and 0.3 to state s1, etc.\n",
    "        [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],\n",
    "        [None, [0.8, 0.1, 0.1], None],\n",
    "    ]\n",
    "\n",
    "rewards = [\n",
    "        [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],\n",
    "        [[0, 0, 0], [0, 0, 0], [0, 0, -50]],\n",
    "        [[0, 0, 0], [+40, 0, 0], [0, 0, 0]],\n",
    "    ]\n",
    "\n",
    "possible_actions = [[0, 1, 2], [0, 2], [1]]\n",
    "\n",
    "def policy_fire(state):\n",
    "    return [0, 2, 1][state]\n",
    "\n",
    "def policy_random(state):\n",
    "    return np.random.choice(possible_actions[state])\n",
    "\n",
    "def policy_safe(state):\n",
    "    return [0, 0, 1][state]\n",
    "\n",
    "class MDPEnvironment(object):\n",
    "    def __init__(self, start_state=0):\n",
    "        self.start_state=start_state\n",
    "        self.reset()\n",
    "    def reset(self):\n",
    "        self.total_rewards = 0\n",
    "        self.state = self.start_state\n",
    "    def step(self, action):\n",
    "        next_state = np.random.choice(range(3), p=transition_probabilities[self.state][action])\n",
    "        reward = rewards[self.state][action][next_state]\n",
    "        self.state = next_state\n",
    "        self.total_rewards += reward\n",
    "        return self.state, reward\n",
    "\n",
    "def run_episode(policy, n_steps, start_state=0, display=True):\n",
    "    env = MDPEnvironment()\n",
    "    if display:\n",
    "        print(\"States (+rewards):\", end=\" \")\n",
    "    for step in range(n_steps):\n",
    "        if display:\n",
    "            if step == 10:\n",
    "                print(\"...\", end=\" \")\n",
    "            elif step < 10:\n",
    "                print(env.state, end=\" \")\n",
    "        action = policy(env.state)\n",
    "        state, reward = env.step(action)\n",
    "        if display and step < 10:\n",
    "            if reward:\n",
    "                print(\"({})\".format(reward), end=\" \")\n",
    "    if display:\n",
    "        print(\"Total rewards =\", env.total_rewards)\n",
    "    return env.total_rewards\n",
    "\n",
    "for policy in (policy_fire, policy_random, policy_safe):\n",
    "    all_totals = []\n",
    "    print(policy.__name__)\n",
    "    for episode in range(1000):\n",
    "        all_totals.append(run_episode(policy, n_steps=100, display=(episode<5)))\n",
    "    print(\"Summary: mean={:.1f}, std={:.1f}, min={}, max={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Q-Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Q-Learning works by watching an agent play (e.g., randomly) and gradually improving its estimates of the Q-Values. Once it has accurate (or close enough) Q-Value estimates, the optimal policy is to choose the action that has the highest Q-Value (i.e., the greedy policy)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_states = 3\n",
    "n_actions = 3\n",
    "n_steps = 20000\n",
    "alpha = 0.01\n",
    "gamma = 0.99\n",
    "exploration_policy = policy_random\n",
    "q_values = np.full((n_states, n_actions), -np.inf)\n",
    "for state, actions in enumerate(possible_actions):\n",
    "    q_values[state][actions]=0\n",
    "\n",
    "env = MDPEnvironment()\n",
    "for step in range(n_steps):\n",
    "    action = exploration_policy(env.state)\n",
    "    state = env.state\n",
    "    next_state, reward = env.step(action)\n",
    "    next_value = np.max(q_values[next_state]) # greedy policy\n",
    "    q_values[state, action] = (1-alpha)*q_values[state, action] + alpha*(reward + gamma * next_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimal_policy(state):\n",
    "    return np.argmax(q_values[state])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 39.13508139,  38.88079412,  35.23025716],\n",
       "       [ 18.9117071 ,         -inf,  20.54567816],\n",
       "       [        -inf,  72.53192111,         -inf]])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "q_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "States (+rewards): 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) ... Total rewards = 230\n",
      "States (+rewards): 0 (10) 0 (10) 0 (10) 0 1 (-50) 2 2 1 (-50) 2 (40) 0 (10) ... Total rewards = 90\n",
      "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... Total rewards = 170\n",
      "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... Total rewards = 220\n",
      "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) ... Total rewards = -50\n",
      "Summary: mean=125.6, std=127.363464, min=-290, max=500\n",
      "\n"
     ]
    }
   ],
   "source": [
    "all_totals = []\n",
    "for episode in range(1000):\n",
    "    all_totals.append(run_episode(optimal_policy, n_steps=100, display=(episode<5)))\n",
    "print(\"Summary: mean={:.1f}, std={:.1f}, min={}, max={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning to Play MsPacman Using the DQN Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: Unfortunately, the first version of the book contained two important errors in this section.\n",
    "\n",
    "1. The actor DQN and critic DQN should have been named _online DQN_ and _target DQN_ respectively. Actor-critic algorithms are a distinct class of algorithms.\n",
    "2. The online DQN is the one that learns and is copied to the target DQN at regular intervals. The target DQN's only role is to estimate the next state's Q-Values for each possible action. This is needed to compute the target Q-Values for training the online DQN, as shown in this equation:\n",
    "\n",
    "$y(s,a) = r + \gamma \cdot \underset{a'}{\max} \, Q_\text{target}(s', a')$\n",
    "\n",
    "* $y(s,a)$ is the target Q-Value to train the online DQN for the state-action pair $(s, a)$.\n",
    "* $r$ is the reward actually collected after playing action $a$ in state $s$.\n",
    "* $\\gamma$ is the discount rate.\n",
    "* $s'$ is the state actually reached after playing action $a$ in state $s$.\n",
    "* $a'$ is one of the possible actions in state $s'$.\n",
    "* $Q_\\text{target}(s', a')$ is the target DQN's estimate of the Q-Value of playing action $a'$ while in state $s'$.\n",
    "\n",
    "I hope these errors did not affect you, and if they did, I sincerely apologize."
   ]
  },
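  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the equation concrete, here is a small numerical sketch with hypothetical values:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "reward = 1.0   # r: reward collected after playing action a in state s\n",
    "gamma = 0.99   # discount rate\n",
    "next_q_values = np.array([0.5, 2.0, 1.3])  # Q_target(s', a') for each a'\n",
    "\n",
    "target = reward + gamma * next_q_values.max()\n",
    "target  # 1.0 + 0.99 * 2.0 = 2.98\n",
    "```"
   ]
  },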
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the MsPacman environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(210, 160, 3)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env = gym.make(\"MsPacman-v0\")\n",
    "obs = env.reset()\n",
    "obs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Discrete(9)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocessing the images is optional but greatly speeds up training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "mspacman_color = 210 + 164 + 74\n",
    "\n",
    "def preprocess_observation(obs):\n",
    "    img = obs[1:176:2, ::2] # crop and downsize\n",
    "    img = img.sum(axis=2) # to greyscale\n",
    "    img[img==mspacman_color] = 0 # Improve contrast\n",
    "    img = (img // 3 - 128).astype(np.int8) # normalize from -128 to 127\n",
    "    return img.reshape(88, 80, 1)\n",
    "\n",
    "img = preprocess_observation(obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the `preprocess_observation()` function is slightly different from the one in the book: instead of representing pixels as 64-bit floats from -1.0 to 1.0, it represents them as signed bytes (from -128 to 127). The benefit is that the replay memory will take up roughly 8 times less RAM (about 6.5 GB instead of 52 GB). The reduced precision has no visible impact on training."
   ]
  },
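  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A back-of-the-envelope check of this estimate (assuming a replay memory of 1,000,000 frames of shape 88 x 80 x 1, consistent with the figures above):\n",
    "\n",
    "```python\n",
    "n_frames = 1000000\n",
    "pixels_per_frame = 88 * 80 * 1\n",
    "\n",
    "gib_float64 = n_frames * pixels_per_frame * 8 / 2**30  # 64-bit floats\n",
    "gib_int8 = n_frames * pixels_per_frame * 1 / 2**30     # signed bytes\n",
    "\n",
    "gib_float64, gib_int8  # roughly 52 GiB vs 6.6 GiB\n",
    "```"
   ]
  },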
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "window.mpl = {};\n",
       "\n",
       "\n",
       "mpl.get_websocket_type = function() {\n",
       "    if (typeof(WebSocket) !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof(MozWebSocket) !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert('Your browser does not have WebSocket support.' +\n",
       "              'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "              'Firefox 4 and 5 are also supported but you ' +\n",
       "              'have to enable WebSockets in about:config.');\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = (this.ws.binaryType != undefined);\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById(\"mpl-warnings\");\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent = (\n",
       "                \"This browser does not support binary websocket messages. \" +\n",
       "                    \"Performance may be slow.\");\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = $('<div/>');\n",
       "    this._root_extra_style(this.root)\n",
       "    this.root.attr('style', 'display: inline-block');\n",
       "\n",
       "    $(parent_element).append(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen =  function () {\n",
       "            fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
       "            fig.send_message(\"send_image_mode\", {});\n",
       "            if (mpl.ratio != 1) {\n",
       "                fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
       "            }\n",
       "            fig.send_message(\"refresh\", {});\n",
       "        }\n",
       "\n",
       "    this.imageObj.onload = function() {\n",
       "            if (fig.image_mode == 'full') {\n",
       "                // Full images could contain transparency (where diff images\n",
       "                // almost always do), so we need to clear the canvas so that\n",
       "                // there is no ghosting.\n",
       "                fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "            }\n",
       "            fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "        };\n",
       "\n",
       "    this.imageObj.onunload = function() {\n",
       "        fig.ws.close();\n",
       "    }\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_header = function() {\n",
       "    var titlebar = $(\n",
       "        '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
       "        'ui-helper-clearfix\"/>');\n",
       "    var titletext = $(\n",
       "        '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
       "        'text-align: center; padding: 3px;\"/>');\n",
       "    titlebar.append(titletext)\n",
       "    this.root.append(titlebar);\n",
       "    this.header = titletext[0];\n",
       "}\n",
       "\n",
       "\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = $('<div/>');\n",
       "\n",
       "    canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
       "\n",
       "    function canvas_keyboard_event(event) {\n",
       "        return fig.key_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    canvas_div.keydown('key_press', canvas_keyboard_event);\n",
       "    canvas_div.keyup('key_release', canvas_keyboard_event);\n",
       "    this.canvas_div = canvas_div\n",
       "    this._canvas_extra_style(canvas_div)\n",
       "    this.root.append(canvas_div);\n",
       "\n",
       "    var canvas = $('<canvas/>');\n",
       "    canvas.addClass('mpl-canvas');\n",
       "    canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
       "\n",
       "    this.canvas = canvas[0];\n",
       "    this.context = canvas[0].getContext(\"2d\");\n",
       "\n",
       "    var backingStore = this.context.backingStorePixelRatio ||\n",
       "\tthis.context.webkitBackingStorePixelRatio ||\n",
       "\tthis.context.mozBackingStorePixelRatio ||\n",
       "\tthis.context.msBackingStorePixelRatio ||\n",
       "\tthis.context.oBackingStorePixelRatio ||\n",
       "\tthis.context.backingStorePixelRatio || 1;\n",
       "\n",
       "    mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband = $('<canvas/>');\n",
       "    rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
       "\n",
       "    var pass_mouse_events = true;\n",
       "\n",
       "    canvas_div.resizable({\n",
       "        start: function(event, ui) {\n",
       "            pass_mouse_events = false;\n",
       "        },\n",
       "        resize: function(event, ui) {\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "        stop: function(event, ui) {\n",
       "            pass_mouse_events = true;\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "    });\n",
       "\n",
       "    function mouse_event_fn(event) {\n",
       "        if (pass_mouse_events)\n",
       "            return fig.mouse_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    rubberband.mousedown('button_press', mouse_event_fn);\n",
       "    rubberband.mouseup('button_release', mouse_event_fn);\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband.mousemove('motion_notify', mouse_event_fn);\n",
       "\n",
       "    rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
       "    rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
       "\n",
       "    canvas_div.on(\"wheel\", function (event) {\n",
       "        event = event.originalEvent;\n",
       "        event['data'] = 'scroll'\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        mouse_event_fn(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.append(canvas);\n",
       "    canvas_div.append(rubberband);\n",
       "\n",
       "    this.rubberband = rubberband;\n",
       "    this.rubberband_canvas = rubberband[0];\n",
       "    this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
       "    this.rubberband_context.strokeStyle = \"#000000\";\n",
       "\n",
       "    this._resize_canvas = function(width, height) {\n",
       "        // Keep the size of the canvas, canvas container, and rubber band\n",
       "        // canvas in synch.\n",
       "        canvas_div.css('width', width)\n",
       "        canvas_div.css('height', height)\n",
       "\n",
       "        canvas.attr('width', width * mpl.ratio);\n",
       "        canvas.attr('height', height * mpl.ratio);\n",
       "        canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
       "\n",
       "        rubberband.attr('width', width);\n",
       "        rubberband.attr('height', height);\n",
       "    }\n",
       "\n",
       "    // Set the figure to an initial 600x600px, this will subsequently be updated\n",
       "    // upon first draw.\n",
       "    this._resize_canvas(600, 600);\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus () {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            // put a spacer in here.\n",
       "            continue;\n",
       "        }\n",
       "        var button = $('<button/>');\n",
       "        button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
       "                        'ui-button-icon-only');\n",
       "        button.attr('role', 'button');\n",
       "        button.attr('aria-disabled', 'false');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "\n",
       "        var icon_img = $('<span/>');\n",
       "        icon_img.addClass('ui-button-icon-primary ui-icon');\n",
       "        icon_img.addClass(image);\n",
       "        icon_img.addClass('ui-corner-all');\n",
       "\n",
       "        var tooltip_span = $('<span/>');\n",
       "        tooltip_span.addClass('ui-button-text');\n",
       "        tooltip_span.html(tooltip);\n",
       "\n",
       "        button.append(icon_img);\n",
       "        button.append(tooltip_span);\n",
       "\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    var fmt_picker_span = $('<span/>');\n",
       "\n",
       "    var fmt_picker = $('<select/>');\n",
       "    fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
       "    fmt_picker_span.append(fmt_picker);\n",
       "    nav_element.append(fmt_picker_span);\n",
       "    this.format_dropdown = fmt_picker[0];\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = $(\n",
       "            '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
       "        fmt_picker.append(option)\n",
       "    }\n",
       "\n",
       "    // Add hover states to the ui-buttons\n",
       "    $( \".ui-button\" ).hover(\n",
       "        function() { $(this).addClass(\"ui-state-hover\");},\n",
       "        function() { $(this).removeClass(\"ui-state-hover\");}\n",
       "    );\n",
       "\n",
       "    var status_bar = $('<span class=\"mpl-message\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_message = function(type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function() {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
       "    }\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1]);\n",
       "        fig.send_message(\"refresh\", {});\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
       "    var x0 = msg['x0'] / mpl.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
       "    var x1 = msg['x1'] / mpl.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0, 0, fig.canvas.width, fig.canvas.height);\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch(cursor)\n",
       "    {\n",
       "    case 0:\n",
       "        cursor = 'pointer';\n",
       "        break;\n",
       "    case 1:\n",
       "        cursor = 'default';\n",
       "        break;\n",
       "    case 2:\n",
       "        cursor = 'crosshair';\n",
       "        break;\n",
       "    case 3:\n",
       "        cursor = 'move';\n",
       "        break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message(\"ack\", {});\n",
       "}\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = \"image/png\";\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src);\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data);\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "        else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig[\"handle_\" + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "}\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function(e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e)\n",
       "        e = window.event;\n",
       "    if (e.target)\n",
       "        targ = e.target;\n",
       "    else if (e.srcElement)\n",
       "        targ = e.srcElement;\n",
       "    if (targ.nodeType == 3) // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "\n",
       "    // jQuery normalizes the pageX and pageY\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    // offset() returns the position of the element relative to the document\n",
       "    var x = e.pageX - $(targ).offset().left;\n",
       "    var y = e.pageY - $(targ).offset().top;\n",
       "\n",
       "    return {\"x\": x, \"y\": y};\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys (original) {\n",
       "  return Object.keys(original).reduce(function (obj, key) {\n",
       "    if (typeof original[key] !== 'object')\n",
       "        obj[key] = original[key]\n",
       "    return obj;\n",
       "  }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function(event, name) {\n",
       "    var canvas_pos = mpl.findpos(event)\n",
       "\n",
       "    if (name === 'button_press')\n",
       "    {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * mpl.ratio;\n",
       "    var y = canvas_pos.y * mpl.ratio;\n",
       "\n",
       "    this.send_message(name, {x: x, y: y, button: event.button,\n",
       "                             step: event.step,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.key_event = function(event, name) {\n",
       "\n",
       "    // Prevent repeat events\n",
       "    if (name == 'key_press')\n",
       "    {\n",
       "        if (event.which === this._key)\n",
       "            return;\n",
       "        else\n",
       "            this._key = event.which;\n",
       "    }\n",
       "    if (name == 'key_release')\n",
       "        this._key = null;\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which != 17)\n",
       "        value += \"ctrl+\";\n",
       "    if (event.altKey && event.which != 18)\n",
       "        value += \"alt+\";\n",
       "    if (event.shiftKey && event.which != 16)\n",
       "        value += \"shift+\";\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, {key: value,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
       "    if (name == 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message(\"toolbar_button\", {name: name});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to  previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function() {\n",
       "        comm.close()\n",
       "    };\n",
       "    ws.send = function(m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function(msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data'])\n",
       "    });\n",
       "    return ws;\n",
       "}\n",
       "\n",
       "mpl.mpl_figure_comm = function(comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = $(\"#\" + id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm)\n",
       "\n",
       "    function ondownload(figure, format) {\n",
       "        window.open(figure.imageObj.src);\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy,\n",
       "                           ondownload,\n",
       "                           element.get(0));\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element.get(0);\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error(\"Failed to find cell for figure\", id, fig);\n",
       "        return;\n",
       "    }\n",
       "\n",
       "    var output_index = fig.cell_info[2]\n",
       "    var cell = fig.cell_info[0];\n",
       "\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
       "    var width = fig.canvas.width/mpl.ratio\n",
       "    fig.root.unbind('remove')\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable()\n",
       "    $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
       "    fig.close_ws(fig, msg);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.close_ws = function(fig, msg){\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width/mpl.ratio\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message(\"ack\", {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () { fig.push_to_output() }, 1000);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items){\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) { continue; };\n",
       "\n",
       "        var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
       "    var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
       "    button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
       "    button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
       "    buttongrp.append(button);\n",
       "    var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
       "    titlebar.prepend(buttongrp);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(el){\n",
       "    var fig = this\n",
       "    el.on(\"remove\", function(){\n",
       "\tfig.close_ws(fig, {});\n",
       "    });\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(el){\n",
       "    // this is important to make the div 'focusable\n",
       "    el.attr('tabindex', 0)\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    }\n",
       "    else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager)\n",
       "        manager = IPython.keyboard_manager;\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which == 13) {\n",
       "        this.canvas_div.blur();\n",
       "        event.shiftKey = false;\n",
       "        // Send a \"J\" for go to next cell\n",
       "        event.which = 74;\n",
       "        event.keyCode = 74;\n",
       "        manager.command_mode();\n",
       "        manager.handle_keydown(event);\n",
       "    }\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.find_output_cell = function(html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i=0; i<ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code'){\n",
       "            for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] == html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "}\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel != null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1100\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving figure preprocessing_plot\n"
     ]
    }
   ],
   "source": [
    "plt.figure(figsize=(11, 7))\n",
    "plt.subplot(121)\n",
     "plt.title(\"Original observation (210×160 RGB)\")\n",
    "plt.imshow(obs)\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(122)\n",
    "plt.title(\"Preprocessed observation (88×80 greyscale)\")\n",
    "plt.imshow(img.reshape(88, 80), interpolation=\"nearest\", cmap=\"gray\")\n",
    "plt.axis(\"off\")\n",
    "save_fig(\"preprocessing_plot\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build DQN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Note: instead of using `tf.contrib.layers.convolution2d()` or `tf.contrib.layers.conv2d()` (as in the first version of the book), we now use `tf.layers.conv2d()`, which did not exist when this chapter was written. This is preferable because anything in `tf.contrib` may change or be deleted without notice, while `tf.layers` is part of the official API. As you will see, the code is mostly the same, except that the parameter names have changed slightly:\n",
    "* the `num_outputs` parameter was renamed to `filters`,\n",
    "* the `stride` parameter was renamed to `strides`,\n",
    "* the `_fn` suffix was removed from parameter names that had it (e.g., `activation_fn` was renamed to `activation`),\n",
    "* the `weights_initializer` parameter was renamed to `kernel_initializer`,\n",
    "* the weights variable was renamed to `\"kernel\"` (instead of `\"weights\"`), and the biases variable was renamed from `\"biases\"` to `\"bias\"`,\n",
    "* and the default `activation` is now `None` instead of `tf.nn.relu`."
   ]
  },
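   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "For example, here is how the first convolutional layer below would have been written with `tf.contrib.layers`, next to its `tf.layers` equivalent (an illustrative sketch only; `X` stands for the input tensor and `initializer` for the weight initializer defined in the next cell):\n",
     "\n",
     "```python\n",
     "# 1st version of the book (tf.contrib):\n",
     "conv = tf.contrib.layers.conv2d(\n",
     "    X, num_outputs=32, kernel_size=(8, 8), stride=4,\n",
     "    padding=\"SAME\", activation_fn=tf.nn.relu,\n",
     "    weights_initializer=initializer)\n",
     "\n",
     "# Official API (note the renamed parameters, and that the ReLU\n",
     "# activation must now be specified explicitly since the default is None):\n",
     "conv = tf.layers.conv2d(\n",
     "    X, filters=32, kernel_size=(8, 8), strides=4,\n",
     "    padding=\"SAME\", activation=tf.nn.relu,\n",
     "    kernel_initializer=initializer)\n",
     "```"
    ]
   },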
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "reset_graph()\n",
    "\n",
    "input_height = 88\n",
    "input_width = 80\n",
    "input_channels = 1\n",
    "conv_n_maps = [32, 64, 64]\n",
    "conv_kernel_sizes = [(8,8), (4,4), (3,3)]\n",
    "conv_strides = [4, 2, 1]\n",
    "conv_paddings = [\"SAME\"] * 3 \n",
    "conv_activation = [tf.nn.relu] * 3\n",
    "n_hidden_in = 64 * 11 * 10  # conv3 has 64 maps of 11x10 each\n",
    "n_hidden = 512\n",
    "hidden_activation = tf.nn.relu\n",
    "n_outputs = env.action_space.n  # 9 discrete actions are available\n",
    "initializer = tf.variance_scaling_initializer()\n",
    "\n",
    "def q_network(X_state, name):\n",
     "    prev_layer = X_state / 128.0  # preprocessing stores pixels as int8 (-128 to 127), so this scales them to the [-1.0, 1.0) range\n",
    "    with tf.variable_scope(name) as scope:\n",
    "        for n_maps, kernel_size, strides, padding, activation in zip(\n",
    "                conv_n_maps, conv_kernel_sizes, conv_strides,\n",
    "                conv_paddings, conv_activation):\n",
    "            prev_layer = tf.layers.conv2d(\n",
    "                prev_layer, filters=n_maps, kernel_size=kernel_size,\n",
    "                strides=strides, padding=padding, activation=activation,\n",
    "                kernel_initializer=initializer)\n",
    "        last_conv_layer_flat = tf.reshape(prev_layer, shape=[-1, n_hidden_in])\n",
    "        hidden = tf.layers.dense(last_conv_layer_flat, n_hidden,\n",
    "                                 activation=hidden_activation,\n",
    "                                 kernel_initializer=initializer)\n",
    "        outputs = tf.layers.dense(hidden, n_outputs,\n",
    "                                  kernel_initializer=initializer)\n",
    "    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\n",
    "                                       scope=scope.name)\n",
    "    trainable_vars_by_name = {var.name[len(scope.name):]: var\n",
    "                              for var in trainable_vars}\n",
    "    return outputs, trainable_vars_by_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_state = tf.placeholder(tf.float32, shape=[None, input_height, input_width,\n",
    "                                            input_channels])\n",
    "online_q_values, online_vars = q_network(X_state, name=\"q_networks/online\")\n",
    "target_q_values, target_vars = q_network(X_state, name=\"q_networks/target\")\n",
    "\n",
    "copy_ops = [target_var.assign(online_vars[var_name])\n",
    "            for var_name, target_var in target_vars.items()]\n",
    "copy_online_to_target = tf.group(*copy_ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'/conv2d/bias:0': <tf.Variable 'q_networks/online/conv2d/bias:0' shape=(32,) dtype=float32_ref>,\n",
       " '/conv2d/kernel:0': <tf.Variable 'q_networks/online/conv2d/kernel:0' shape=(8, 8, 1, 32) dtype=float32_ref>,\n",
       " '/conv2d_1/bias:0': <tf.Variable 'q_networks/online/conv2d_1/bias:0' shape=(64,) dtype=float32_ref>,\n",
       " '/conv2d_1/kernel:0': <tf.Variable 'q_networks/online/conv2d_1/kernel:0' shape=(4, 4, 32, 64) dtype=float32_ref>,\n",
       " '/conv2d_2/bias:0': <tf.Variable 'q_networks/online/conv2d_2/bias:0' shape=(64,) dtype=float32_ref>,\n",
       " '/conv2d_2/kernel:0': <tf.Variable 'q_networks/online/conv2d_2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,\n",
       " '/dense/bias:0': <tf.Variable 'q_networks/online/dense/bias:0' shape=(512,) dtype=float32_ref>,\n",
       " '/dense/kernel:0': <tf.Variable 'q_networks/online/dense/kernel:0' shape=(7040, 512) dtype=float32_ref>,\n",
       " '/dense_1/bias:0': <tf.Variable 'q_networks/online/dense_1/bias:0' shape=(9,) dtype=float32_ref>,\n",
       " '/dense_1/kernel:0': <tf.Variable 'q_networks/online/dense_1/kernel:0' shape=(512, 9) dtype=float32_ref>}"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "online_vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 0.001\n",
    "momentum = 0.95\n",
    "\n",
    "with tf.variable_scope(\"train\"):\n",
    "    X_action = tf.placeholder(tf.int32, shape=[None])\n",
    "    y = tf.placeholder(tf.float32, shape=[None, 1])\n",
    "    q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n",
    "                            axis=1, keepdims=True)\n",
    "    error = tf.abs(y - q_value)\n",
    "    clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n",
    "    linear_error = 2 * (error - clipped_error)\n",
    "    loss = tf.reduce_mean(tf.square(clipped_error) + linear_error)\n",
    "\n",
    "    global_step = tf.Variable(0, trainable=False, name='global_step')\n",
    "    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)\n",
    "    training_op = optimizer.minimize(loss, global_step=global_step)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: in the first version of the book, the loss function was simply the squared error between the target Q-Values (`y`) and the estimated Q-Values (`q_value`). However, because the experiences are very noisy, it is better to use a quadratic loss only for small errors (below 1.0) and a linear loss (with a slope of 2) for larger errors, which is what the code above computes. This way, large errors don't push the model parameters around as much. Note that we also tweaked some hyperparameters: we use a smaller learning rate, and Nesterov Accelerated Gradients rather than Adam optimization, since adaptive gradient methods can sometimes generalize poorly, according to this [paper](https://arxiv.org/abs/1705.08292). We also tweaked a few other hyperparameters below: a larger replay memory, a longer decay for the $\\epsilon$-greedy policy, a larger discount rate, less frequent copies of the online DQN to the target DQN, and so on."
   ]
  },
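  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To convince yourself that the `clipped_error` + `linear_error` trick above really computes the intended piecewise loss (quadratic below 1.0, linear with a slope of 2 beyond), here is a quick standalone NumPy check (it is not used by the training code):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone sanity check of the piecewise loss computed in the graph above:\n",
    "# for |error| <= 1.0 the term is error**2, beyond that it is 2*|error| - 1.\n",
    "errors = np.array([0.0, 0.5, 1.0, 2.0, 5.0])\n",
    "clipped = np.clip(errors, 0.0, 1.0)\n",
    "loss_terms = clipped ** 2 + 2 * (errors - clipped)\n",
    "expected = np.where(errors <= 1.0, errors ** 2, 2 * errors - 1)\n",
    "assert np.allclose(loss_terms, expected)\n",
    "loss_terms  # array([0., 0.25, 1., 3., 9.])"
   ]
  },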
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use this `ReplayMemory` class instead of a `deque` because it is much faster for random access (thanks to @NileshPS who contributed it). Moreover, we default to sampling with replacement, which is much faster than sampling without replacement for large replay memories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayMemory:\n",
    "    def __init__(self, maxlen):\n",
    "        self.maxlen = maxlen\n",
    "        self.buf = np.empty(shape=maxlen, dtype=object)\n",
    "        self.index = 0\n",
    "        self.length = 0\n",
    "        \n",
    "    def append(self, data):\n",
    "        self.buf[self.index] = data\n",
    "        self.length = min(self.length + 1, self.maxlen)\n",
    "        self.index = (self.index + 1) % self.maxlen\n",
    "    \n",
    "    def sample(self, batch_size, with_replacement=True):\n",
    "        if with_replacement:\n",
    "            indices = np.random.randint(self.length, size=batch_size) # faster\n",
    "        else:\n",
    "            indices = np.random.permutation(self.length)[:batch_size]\n",
    "        return self.buf[indices]"
   ]
  },
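  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the ring-buffer behavior (a standalone demo, not needed for training): once the memory is full, `append()` wraps around and overwrites the oldest entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "demo = ReplayMemory(maxlen=3)\n",
    "for item in range(5):\n",
    "    demo.append(item)\n",
    "# items 0 and 1 were overwritten, so only 2, 3 and 4 remain\n",
    "print(demo.length, sorted(demo.buf.tolist()))  # prints: 3 [2, 3, 4]"
   ]
  },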
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "replay_memory_size = 500000\n",
    "replay_memory = ReplayMemory(replay_memory_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_memories(batch_size):\n",
    "    cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n",
    "    for memory in replay_memory.sample(batch_size):\n",
    "        for col, value in zip(cols, memory):\n",
    "            col.append(value)\n",
    "    cols = [np.array(col) for col in cols]\n",
    "    return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps_min = 0.1\n",
    "eps_max = 1.0\n",
    "eps_decay_steps = 2000000\n",
    "\n",
    "def epsilon_greedy(q_values, step):\n",
    "    epsilon = max(eps_min, eps_max - (eps_max-eps_min) * step/eps_decay_steps)\n",
    "    if np.random.rand() < epsilon:\n",
    "        return np.random.randint(n_outputs) # random action\n",
    "    else:\n",
    "        return np.argmax(q_values) # optimal action"
   ]
  },
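  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A few sample points of this linear decay schedule (a standalone check): $\\epsilon$ starts at 1.0, reaches 0.55 halfway through the 2,000,000 decay steps, and then stays at 0.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epsilon decays linearly from eps_max to eps_min over eps_decay_steps,\n",
    "# then stays at eps_min:\n",
    "for step in (0, 1000000, 2000000, 4000000):\n",
    "    epsilon = max(eps_min, eps_max - (eps_max - eps_min) * step / eps_decay_steps)\n",
    "    print(step, epsilon)\n",
    "# prints 1.0 at step 0, 0.55 at 1,000,000, then 0.1 from 2,000,000 onwards"
   ]
  },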
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 4000000  # total number of training steps\n",
    "training_start = 10000  # start training after 10,000 game iterations\n",
    "training_interval = 4  # run a training step every 4 game iterations\n",
    "save_steps = 1000  # save the model every 1,000 training steps\n",
    "copy_steps = 10000  # copy online DQN to target DQN every 10,000 training steps\n",
    "discount_rate = 0.99\n",
    "skip_start = 90  # Skip the start of every game (it's just waiting time).\n",
    "batch_size = 50\n",
    "iteration = 0  # game iterations\n",
    "checkpoint_path = \"./my_dqn.ckpt\"\n",
    "done = True # env needs to be reset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A few variables for tracking progress:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_val = np.infty\n",
    "game_length = 0\n",
    "total_max_q = 0\n",
    "mean_max_q = 0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now the main training loop!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-09-25 13:55:15,610] Restoring parameters from ./my_dqn.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Iteration 1270171\tTraining step 315001/4000000 (7.9)%\tLoss 2.651937\tMean Max-Q 30.964941   "
     ]
    }
   ],
   "source": [
    "with tf.Session() as sess:\n",
    "    if os.path.isfile(checkpoint_path + \".index\"):\n",
    "        saver.restore(sess, checkpoint_path)\n",
    "    else:\n",
    "        init.run()\n",
    "        copy_online_to_target.run()\n",
    "    while True:\n",
    "        step = global_step.eval()\n",
    "        if step >= n_steps:\n",
    "            break\n",
    "        iteration += 1\n",
    "        print(\"\\rIteration {}\\tTraining step {}/{} ({:.1f}%)\\tLoss {:5f}\\tMean Max-Q {:5f}   \".format(\n",
    "            iteration, step, n_steps, step * 100 / n_steps, loss_val, mean_max_q), end=\"\")\n",
    "        if done: # game over, start again\n",
    "            obs = env.reset()\n",
    "            for skip in range(skip_start): # skip the start of each game\n",
    "                obs, reward, done, info = env.step(0)\n",
    "            state = preprocess_observation(obs)\n",
    "\n",
    "        # Online DQN evaluates what to do\n",
    "        q_values = online_q_values.eval(feed_dict={X_state: [state]})\n",
    "        action = epsilon_greedy(q_values, step)\n",
    "\n",
    "        # Online DQN plays\n",
    "        obs, reward, done, info = env.step(action)\n",
    "        next_state = preprocess_observation(obs)\n",
    "\n",
    "        # Let's memorize what happened\n",
    "        replay_memory.append((state, action, reward, next_state, 1.0 - done))\n",
    "        state = next_state\n",
    "\n",
    "        # Compute statistics for tracking progress (not shown in the book)\n",
    "        total_max_q += q_values.max()\n",
    "        game_length += 1\n",
    "        if done:\n",
    "            mean_max_q = total_max_q / game_length\n",
    "            total_max_q = 0.0\n",
    "            game_length = 0\n",
    "\n",
    "        if iteration < training_start or iteration % training_interval != 0:\n",
    "            continue # only train after warmup period and at regular intervals\n",
    "        \n",
    "        # Sample memories and use the target DQN to produce the target Q-Value\n",
    "        X_state_val, X_action_val, rewards, X_next_state_val, continues = (\n",
    "            sample_memories(batch_size))\n",
    "        next_q_values = target_q_values.eval(\n",
    "            feed_dict={X_state: X_next_state_val})\n",
    "        max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)\n",
    "        y_val = rewards + continues * discount_rate * max_next_q_values\n",
    "\n",
    "        # Train the online DQN\n",
    "        _, loss_val = sess.run([training_op, loss], feed_dict={\n",
    "            X_state: X_state_val, X_action: X_action_val, y: y_val})\n",
    "\n",
    "        # Regularly copy the online DQN to the target DQN\n",
    "        if step % copy_steps == 0:\n",
    "            copy_online_to_target.run()\n",
    "\n",
    "        # And save regularly\n",
    "        if step % save_steps == 0:\n",
    "            saver.save(sess, checkpoint_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can interrupt the cell above at any time to test your agent using the cell below. You can then run the cell above once again: it will load the last saved parameters and resume training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-09-25 13:53:39,307] Restoring parameters from ./my_dqn.ckpt\n"
     ]
    }
   ],
   "source": [
    "frames = []\n",
    "n_max_steps = 10000\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, checkpoint_path)\n",
    "\n",
    "    obs = env.reset()\n",
    "    for step in range(n_max_steps):\n",
    "        state = preprocess_observation(obs)\n",
    "\n",
    "        # Online DQN evaluates what to do\n",
    "        q_values = online_q_values.eval(feed_dict={X_state: [state]})\n",
    "        action = np.argmax(q_values)\n",
    "\n",
    "        # Online DQN plays\n",
    "        obs, reward, done, info = env.step(action)\n",
    "\n",
    "        img = env.render(mode=\"rgb_array\")\n",
    "        frames.append(img)\n",
    "\n",
    "        if done:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extra material"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing for Breakout"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a preprocessing function you can use to train a DQN for the Breakout-v0 Atari game:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def preprocess_observation(obs):\n",
    "    img = obs[34:194:2, ::2] # crop and downsize\n",
    "    return np.mean(img, axis=2).reshape(80, 80) / 255.0"
   ]
  },
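  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Atari frames are 210\u00d7160 RGB images, so a quick shape check on a dummy frame (a standalone sketch) confirms the crop-and-downsample result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rows 34:194:2 -> 80 rows, cols ::2 -> 80 cols, then mean over the\n",
    "# 3 color channels leaves a single 80x80 grayscale image in [0, 1]:\n",
    "fake_obs = np.zeros((210, 160, 3), dtype=np.uint8)  # Atari's native frame size\n",
    "preprocess_observation(fake_obs).shape  # (80, 80)"
   ]
  },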
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-09-25 13:54:27,989] Making new env: Breakout-v0\n"
     ]
    }
   ],
   "source": [
    "env = gym.make(\"Breakout-v0\")\n",
    "obs = env.reset()\n",
    "for step in range(10):\n",
    "    obs, _, _, _ = env.step(1)\n",
    "\n",
    "img = preprocess_observation(obs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "window.mpl = {};\n",
       "\n",
       "\n",
       "mpl.get_websocket_type = function() {\n",
       "    if (typeof(WebSocket) !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof(MozWebSocket) !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert('Your browser does not have WebSocket support.' +\n",
       "              'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "              'Firefox 4 and 5 are also supported but you ' +\n",
       "              'have to enable WebSockets in about:config.');\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = (this.ws.binaryType != undefined);\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById(\"mpl-warnings\");\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent = (\n",
       "                \"This browser does not support binary websocket messages. \" +\n",
       "                    \"Performance may be slow.\");\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = $('<div/>');\n",
       "    this._root_extra_style(this.root)\n",
       "    this.root.attr('style', 'display: inline-block');\n",
       "\n",
       "    $(parent_element).append(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen =  function () {\n",
       "            fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
       "            fig.send_message(\"send_image_mode\", {});\n",
       "            if (mpl.ratio != 1) {\n",
       "                fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
       "            }\n",
       "            fig.send_message(\"refresh\", {});\n",
       "        }\n",
       "\n",
       "    this.imageObj.onload = function() {\n",
       "            if (fig.image_mode == 'full') {\n",
       "                // Full images could contain transparency (where diff images\n",
       "                // almost always do), so we need to clear the canvas so that\n",
       "                // there is no ghosting.\n",
       "                fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "            }\n",
       "            fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "        };\n",
       "\n",
       "    this.imageObj.onunload = function() {\n",
       "        this.ws.close();\n",
       "    }\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_header = function() {\n",
       "    var titlebar = $(\n",
       "        '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
       "        'ui-helper-clearfix\"/>');\n",
       "    var titletext = $(\n",
       "        '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
       "        'text-align: center; padding: 3px;\"/>');\n",
       "    titlebar.append(titletext)\n",
       "    this.root.append(titlebar);\n",
       "    this.header = titletext[0];\n",
       "}\n",
       "\n",
       "\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = $('<div/>');\n",
       "\n",
       "    canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
       "\n",
       "    function canvas_keyboard_event(event) {\n",
       "        return fig.key_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    canvas_div.keydown('key_press', canvas_keyboard_event);\n",
       "    canvas_div.keyup('key_release', canvas_keyboard_event);\n",
       "    this.canvas_div = canvas_div\n",
       "    this._canvas_extra_style(canvas_div)\n",
       "    this.root.append(canvas_div);\n",
       "\n",
       "    var canvas = $('<canvas/>');\n",
       "    canvas.addClass('mpl-canvas');\n",
       "    canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
       "\n",
       "    this.canvas = canvas[0];\n",
       "    this.context = canvas[0].getContext(\"2d\");\n",
       "\n",
       "    var backingStore = this.context.backingStorePixelRatio ||\n",
       "\tthis.context.webkitBackingStorePixelRatio ||\n",
       "\tthis.context.mozBackingStorePixelRatio ||\n",
       "\tthis.context.msBackingStorePixelRatio ||\n",
       "\tthis.context.oBackingStorePixelRatio ||\n",
       "\tthis.context.backingStorePixelRatio || 1;\n",
       "\n",
       "    mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband = $('<canvas/>');\n",
       "    rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
       "\n",
       "    var pass_mouse_events = true;\n",
       "\n",
       "    canvas_div.resizable({\n",
       "        start: function(event, ui) {\n",
       "            pass_mouse_events = false;\n",
       "        },\n",
       "        resize: function(event, ui) {\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "        stop: function(event, ui) {\n",
       "            pass_mouse_events = true;\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "    });\n",
       "\n",
       "    function mouse_event_fn(event) {\n",
       "        if (pass_mouse_events)\n",
       "            return fig.mouse_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    rubberband.mousedown('button_press', mouse_event_fn);\n",
       "    rubberband.mouseup('button_release', mouse_event_fn);\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband.mousemove('motion_notify', mouse_event_fn);\n",
       "\n",
       "    rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
       "    rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
       "\n",
       "    canvas_div.on(\"wheel\", function (event) {\n",
       "        event = event.originalEvent;\n",
       "        event['data'] = 'scroll'\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        mouse_event_fn(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.append(canvas);\n",
       "    canvas_div.append(rubberband);\n",
       "\n",
       "    this.rubberband = rubberband;\n",
       "    this.rubberband_canvas = rubberband[0];\n",
       "    this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
       "    this.rubberband_context.strokeStyle = \"#000000\";\n",
       "\n",
       "    this._resize_canvas = function(width, height) {\n",
       "        // Keep the size of the canvas, canvas container, and rubber band\n",
       "        // canvas in synch.\n",
       "        canvas_div.css('width', width)\n",
       "        canvas_div.css('height', height)\n",
       "\n",
       "        canvas.attr('width', width * mpl.ratio);\n",
       "        canvas.attr('height', height * mpl.ratio);\n",
       "        canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
       "\n",
       "        rubberband.attr('width', width);\n",
       "        rubberband.attr('height', height);\n",
       "    }\n",
       "\n",
       "    // Set the figure to an initial 600x600px, this will subsequently be updated\n",
       "    // upon first draw.\n",
       "    this._resize_canvas(600, 600);\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus () {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            // put a spacer in here.\n",
       "            continue;\n",
       "        }\n",
       "        var button = $('<button/>');\n",
       "        button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
       "                        'ui-button-icon-only');\n",
       "        button.attr('role', 'button');\n",
       "        button.attr('aria-disabled', 'false');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "\n",
       "        var icon_img = $('<span/>');\n",
       "        icon_img.addClass('ui-button-icon-primary ui-icon');\n",
       "        icon_img.addClass(image);\n",
       "        icon_img.addClass('ui-corner-all');\n",
       "\n",
       "        var tooltip_span = $('<span/>');\n",
       "        tooltip_span.addClass('ui-button-text');\n",
       "        tooltip_span.html(tooltip);\n",
       "\n",
       "        button.append(icon_img);\n",
       "        button.append(tooltip_span);\n",
       "\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    var fmt_picker_span = $('<span/>');\n",
       "\n",
       "    var fmt_picker = $('<select/>');\n",
       "    fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
       "    fmt_picker_span.append(fmt_picker);\n",
       "    nav_element.append(fmt_picker_span);\n",
       "    this.format_dropdown = fmt_picker[0];\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = $(\n",
       "            '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
       "        fmt_picker.append(option)\n",
       "    }\n",
       "\n",
       "    // Add hover states to the ui-buttons\n",
       "    $( \".ui-button\" ).hover(\n",
       "        function() { $(this).addClass(\"ui-state-hover\");},\n",
       "        function() { $(this).removeClass(\"ui-state-hover\");}\n",
       "    );\n",
       "\n",
       "    var status_bar = $('<span class=\"mpl-message\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_message = function(type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function() {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
       "    }\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1]);\n",
       "        fig.send_message(\"refresh\", {});\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
       "    var x0 = msg['x0'] / mpl.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
       "    var x1 = msg['x1'] / mpl.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0, 0, fig.canvas.width, fig.canvas.height);\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch(cursor)\n",
       "    {\n",
       "    case 0:\n",
       "        cursor = 'pointer';\n",
       "        break;\n",
       "    case 1:\n",
       "        cursor = 'default';\n",
       "        break;\n",
       "    case 2:\n",
       "        cursor = 'crosshair';\n",
       "        break;\n",
       "    case 3:\n",
       "        cursor = 'move';\n",
       "        break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message(\"ack\", {});\n",
       "}\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = \"image/png\";\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src);\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data);\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "        else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig[\"handle_\" + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "}\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function(e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e)\n",
       "        e = window.event;\n",
       "    if (e.target)\n",
       "        targ = e.target;\n",
       "    else if (e.srcElement)\n",
       "        targ = e.srcElement;\n",
       "    if (targ.nodeType == 3) // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "\n",
       "    // jQuery normalizes the pageX and pageY\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    // offset() returns the position of the element relative to the document\n",
       "    var x = e.pageX - $(targ).offset().left;\n",
       "    var y = e.pageY - $(targ).offset().top;\n",
       "\n",
       "    return {\"x\": x, \"y\": y};\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys (original) {\n",
       "  return Object.keys(original).reduce(function (obj, key) {\n",
       "    if (typeof original[key] !== 'object')\n",
       "        obj[key] = original[key]\n",
       "    return obj;\n",
       "  }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function(event, name) {\n",
       "    var canvas_pos = mpl.findpos(event)\n",
       "\n",
       "    if (name === 'button_press')\n",
       "    {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * mpl.ratio;\n",
       "    var y = canvas_pos.y * mpl.ratio;\n",
       "\n",
       "    this.send_message(name, {x: x, y: y, button: event.button,\n",
       "                             step: event.step,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.key_event = function(event, name) {\n",
       "\n",
       "    // Prevent repeat events\n",
       "    if (name == 'key_press')\n",
       "    {\n",
       "        if (event.which === this._key)\n",
       "            return;\n",
       "        else\n",
       "            this._key = event.which;\n",
       "    }\n",
       "    if (name == 'key_release')\n",
       "        this._key = null;\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which != 17)\n",
       "        value += \"ctrl+\";\n",
       "    if (event.altKey && event.which != 18)\n",
       "        value += \"alt+\";\n",
       "    if (event.shiftKey && event.which != 16)\n",
       "        value += \"shift+\";\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, {key: value,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
       "    if (name == 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message(\"toolbar_button\", {name: name});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to  previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function() {\n",
       "        comm.close()\n",
       "    };\n",
       "    ws.send = function(m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function(msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data'])\n",
       "    });\n",
       "    return ws;\n",
       "}\n",
       "\n",
       "mpl.mpl_figure_comm = function(comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = $(\"#\" + id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm)\n",
       "\n",
       "    function ondownload(figure, format) {\n",
       "        window.open(figure.imageObj.src);\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy,\n",
       "                           ondownload,\n",
       "                           element.get(0));\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element.get(0);\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error(\"Failed to find cell for figure\", id, fig);\n",
       "        return;\n",
       "    }\n",
       "\n",
       "    var output_index = fig.cell_info[2]\n",
       "    var cell = fig.cell_info[0];\n",
       "\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
       "    var width = fig.canvas.width/mpl.ratio\n",
       "    fig.root.unbind('remove')\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable()\n",
       "    $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
       "    fig.close_ws(fig, msg);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.close_ws = function(fig, msg){\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width/mpl.ratio\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message(\"ack\", {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () { fig.push_to_output() }, 1000);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items){\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) { continue; };\n",
       "\n",
       "        var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
       "    var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
       "    button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
       "    button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
       "    buttongrp.append(button);\n",
       "    var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
       "    titlebar.prepend(buttongrp);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(el){\n",
       "    var fig = this\n",
       "    el.on(\"remove\", function(){\n",
       "\tfig.close_ws(fig, {});\n",
       "    });\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(el){\n",
       "    // this is important to make the div 'focusable\n",
       "    el.attr('tabindex', 0)\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    }\n",
       "    else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager)\n",
       "        manager = IPython.keyboard_manager;\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which == 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.find_output_cell = function(html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i=0; i<ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code'){\n",
       "            for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] == html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "}\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel != null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1100\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(11, 7))\n",
    "plt.subplot(121)\n",
    "plt.title(\"Original observation (160×210 RGB)\")\n",
    "plt.imshow(obs)\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(122)\n",
    "plt.title(\"Preprocessed observation (80×80 grayscale)\")\n",
    "plt.imshow(img, interpolation=\"nearest\", cmap=\"gray\")\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "As you can see, a single image does not tell you the direction and speed of the ball, which are crucial pieces of information for playing this game. For this reason, it is best to combine several consecutive observations into the environment's state representation. One way to do that is to create a multi-channel image, with one channel per recent observation. Another is to merge all recent observations into a single-channel image, using `np.max()`. In this case, we need to dim the older images so that the DQN can distinguish the past from the present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "\n",
     "def combine_observations_multichannel(preprocessed_observations):\n",
     "    # Stack the recent frames along a new last axis: (height, width, n_frames).\n",
     "    return np.array(preprocessed_observations).transpose([1, 2, 0])\n",
     "\n",
     "def combine_observations_singlechannel(preprocessed_observations, dim_factor=0.5):\n",
     "    # Dim each frame by dim_factor for every step it lies in the past, then\n",
     "    # merge with a pixel-wise max: recent positions appear brighter than old ones.\n",
     "    dimmed_observations = [obs * dim_factor**index\n",
     "                           for index, obs in enumerate(reversed(preprocessed_observations))]\n",
     "    return np.max(np.array(dimmed_observations), axis=0)\n",
    "\n",
    "n_observations_per_state = 3\n",
    "preprocessed_observations = deque([], maxlen=n_observations_per_state)\n",
    "\n",
    "obs = env.reset()\n",
    "for step in range(10):\n",
    "    obs, _, _, _ = env.step(1)\n",
    "    preprocessed_observations.append(preprocess_observation(obs))"
   ]
  },
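   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As a quick sanity check (assuming `preprocess_observation()` produces 80×80 grayscale frames, as in the plot above), the two combination functions should yield one three-channel image and one single-channel image:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# The deque holds the 3 most recent preprocessed frames at this point.\n",
     "img_multi = combine_observations_multichannel(preprocessed_observations)\n",
     "img_single = combine_observations_singlechannel(preprocessed_observations)\n",
     "img_multi.shape, img_single.shape  # expected: (80, 80, 3) and (80, 80)"
    ]
   },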
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "window.mpl = {};\n",
       "\n",
       "\n",
       "mpl.get_websocket_type = function() {\n",
       "    if (typeof(WebSocket) !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof(MozWebSocket) !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert('Your browser does not have WebSocket support.' +\n",
       "              'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "              'Firefox 4 and 5 are also supported but you ' +\n",
       "              'have to enable WebSockets in about:config.');\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = (this.ws.binaryType != undefined);\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById(\"mpl-warnings\");\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent = (\n",
       "                \"This browser does not support binary websocket messages. \" +\n",
       "                    \"Performance may be slow.\");\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = $('<div/>');\n",
       "    this._root_extra_style(this.root)\n",
       "    this.root.attr('style', 'display: inline-block');\n",
       "\n",
       "    $(parent_element).append(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen =  function () {\n",
       "            fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
       "            fig.send_message(\"send_image_mode\", {});\n",
       "            if (mpl.ratio != 1) {\n",
       "                fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
       "            }\n",
       "            fig.send_message(\"refresh\", {});\n",
       "        }\n",
       "\n",
       "    this.imageObj.onload = function() {\n",
       "            if (fig.image_mode == 'full') {\n",
       "                // Full images could contain transparency (where diff images\n",
       "                // almost always do), so we need to clear the canvas so that\n",
       "                // there is no ghosting.\n",
       "                fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "            }\n",
       "            fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "        };\n",
       "\n",
       "    this.imageObj.onunload = function() {\n",
       "        this.ws.close();\n",
       "    }\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_header = function() {\n",
       "    var titlebar = $(\n",
       "        '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
       "        'ui-helper-clearfix\"/>');\n",
       "    var titletext = $(\n",
       "        '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
       "        'text-align: center; padding: 3px;\"/>');\n",
       "    titlebar.append(titletext)\n",
       "    this.root.append(titlebar);\n",
       "    this.header = titletext[0];\n",
       "}\n",
       "\n",
       "\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = $('<div/>');\n",
       "\n",
       "    canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
       "\n",
       "    function canvas_keyboard_event(event) {\n",
       "        return fig.key_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    canvas_div.keydown('key_press', canvas_keyboard_event);\n",
       "    canvas_div.keyup('key_release', canvas_keyboard_event);\n",
       "    this.canvas_div = canvas_div\n",
       "    this._canvas_extra_style(canvas_div)\n",
       "    this.root.append(canvas_div);\n",
       "\n",
       "    var canvas = $('<canvas/>');\n",
       "    canvas.addClass('mpl-canvas');\n",
       "    canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
       "\n",
       "    this.canvas = canvas[0];\n",
       "    this.context = canvas[0].getContext(\"2d\");\n",
       "\n",
       "    var backingStore = this.context.backingStorePixelRatio ||\n",
       "\tthis.context.webkitBackingStorePixelRatio ||\n",
       "\tthis.context.mozBackingStorePixelRatio ||\n",
       "\tthis.context.msBackingStorePixelRatio ||\n",
       "\tthis.context.oBackingStorePixelRatio ||\n",
       "\tthis.context.backingStorePixelRatio || 1;\n",
       "\n",
       "    mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband = $('<canvas/>');\n",
       "    rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
       "\n",
       "    var pass_mouse_events = true;\n",
       "\n",
       "    canvas_div.resizable({\n",
       "        start: function(event, ui) {\n",
       "            pass_mouse_events = false;\n",
       "        },\n",
       "        resize: function(event, ui) {\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "        stop: function(event, ui) {\n",
       "            pass_mouse_events = true;\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "    });\n",
       "\n",
       "    function mouse_event_fn(event) {\n",
       "        if (pass_mouse_events)\n",
       "            return fig.mouse_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    rubberband.mousedown('button_press', mouse_event_fn);\n",
       "    rubberband.mouseup('button_release', mouse_event_fn);\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband.mousemove('motion_notify', mouse_event_fn);\n",
       "\n",
       "    rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
       "    rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
       "\n",
       "    canvas_div.on(\"wheel\", function (event) {\n",
       "        event = event.originalEvent;\n",
       "        event['data'] = 'scroll'\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        mouse_event_fn(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.append(canvas);\n",
       "    canvas_div.append(rubberband);\n",
       "\n",
       "    this.rubberband = rubberband;\n",
       "    this.rubberband_canvas = rubberband[0];\n",
       "    this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
       "    this.rubberband_context.strokeStyle = \"#000000\";\n",
       "\n",
       "    this._resize_canvas = function(width, height) {\n",
       "        // Keep the size of the canvas, canvas container, and rubber band\n",
       "        // canvas in synch.\n",
       "        canvas_div.css('width', width)\n",
       "        canvas_div.css('height', height)\n",
       "\n",
       "        canvas.attr('width', width * mpl.ratio);\n",
       "        canvas.attr('height', height * mpl.ratio);\n",
       "        canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
       "\n",
       "        rubberband.attr('width', width);\n",
       "        rubberband.attr('height', height);\n",
       "    }\n",
       "\n",
       "    // Set the figure to an initial 600x600px, this will subsequently be updated\n",
       "    // upon first draw.\n",
       "    this._resize_canvas(600, 600);\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus () {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            // put a spacer in here.\n",
       "            continue;\n",
       "        }\n",
       "        var button = $('<button/>');\n",
       "        button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
       "                        'ui-button-icon-only');\n",
       "        button.attr('role', 'button');\n",
       "        button.attr('aria-disabled', 'false');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "\n",
       "        var icon_img = $('<span/>');\n",
       "        icon_img.addClass('ui-button-icon-primary ui-icon');\n",
       "        icon_img.addClass(image);\n",
       "        icon_img.addClass('ui-corner-all');\n",
       "\n",
       "        var tooltip_span = $('<span/>');\n",
       "        tooltip_span.addClass('ui-button-text');\n",
       "        tooltip_span.html(tooltip);\n",
       "\n",
       "        button.append(icon_img);\n",
       "        button.append(tooltip_span);\n",
       "\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    var fmt_picker_span = $('<span/>');\n",
       "\n",
       "    var fmt_picker = $('<select/>');\n",
       "    fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
       "    fmt_picker_span.append(fmt_picker);\n",
       "    nav_element.append(fmt_picker_span);\n",
       "    this.format_dropdown = fmt_picker[0];\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = $(\n",
       "            '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
       "        fmt_picker.append(option)\n",
       "    }\n",
       "\n",
       "    // Add hover states to the ui-buttons\n",
       "    $( \".ui-button\" ).hover(\n",
       "        function() { $(this).addClass(\"ui-state-hover\");},\n",
       "        function() { $(this).removeClass(\"ui-state-hover\");}\n",
       "    );\n",
       "\n",
       "    var status_bar = $('<span class=\"mpl-message\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_message = function(type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function() {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
       "    }\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1]);\n",
       "        fig.send_message(\"refresh\", {});\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
       "    var x0 = msg['x0'] / mpl.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
       "    var x1 = msg['x1'] / mpl.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0, 0, fig.canvas.width, fig.canvas.height);\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch(cursor)\n",
       "    {\n",
       "    case 0:\n",
       "        cursor = 'pointer';\n",
       "        break;\n",
       "    case 1:\n",
       "        cursor = 'default';\n",
       "        break;\n",
       "    case 2:\n",
       "        cursor = 'crosshair';\n",
       "        break;\n",
       "    case 3:\n",
       "        cursor = 'move';\n",
       "        break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message(\"ack\", {});\n",
       "}\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = \"image/png\";\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src);\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data);\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "        else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig[\"handle_\" + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "}\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function(e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e)\n",
       "        e = window.event;\n",
       "    if (e.target)\n",
       "        targ = e.target;\n",
       "    else if (e.srcElement)\n",
       "        targ = e.srcElement;\n",
       "    if (targ.nodeType == 3) // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "\n",
       "    // jQuery normalizes the pageX and pageY\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    // offset() returns the position of the element relative to the document\n",
       "    var x = e.pageX - $(targ).offset().left;\n",
       "    var y = e.pageY - $(targ).offset().top;\n",
       "\n",
       "    return {\"x\": x, \"y\": y};\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys (original) {\n",
       "  return Object.keys(original).reduce(function (obj, key) {\n",
       "    if (typeof original[key] !== 'object')\n",
       "        obj[key] = original[key]\n",
       "    return obj;\n",
       "  }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function(event, name) {\n",
       "    var canvas_pos = mpl.findpos(event)\n",
       "\n",
       "    if (name === 'button_press')\n",
       "    {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * mpl.ratio;\n",
       "    var y = canvas_pos.y * mpl.ratio;\n",
       "\n",
       "    this.send_message(name, {x: x, y: y, button: event.button,\n",
       "                             step: event.step,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.key_event = function(event, name) {\n",
       "\n",
       "    // Prevent repeat events\n",
       "    if (name == 'key_press')\n",
       "    {\n",
       "        if (event.which === this._key)\n",
       "            return;\n",
       "        else\n",
       "            this._key = event.which;\n",
       "    }\n",
       "    if (name == 'key_release')\n",
       "        this._key = null;\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which != 17)\n",
       "        value += \"ctrl+\";\n",
       "    if (event.altKey && event.which != 18)\n",
       "        value += \"alt+\";\n",
       "    if (event.shiftKey && event.which != 16)\n",
       "        value += \"shift+\";\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, {key: value,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
       "    if (name == 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message(\"toolbar_button\", {name: name});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to  previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function() {\n",
       "        comm.close()\n",
       "    };\n",
       "    ws.send = function(m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function(msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data'])\n",
       "    });\n",
       "    return ws;\n",
       "}\n",
       "\n",
       "mpl.mpl_figure_comm = function(comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = $(\"#\" + id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm)\n",
       "\n",
       "    function ondownload(figure, format) {\n",
       "        window.open(figure.imageObj.src);\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy,\n",
       "                           ondownload,\n",
       "                           element.get(0));\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element.get(0);\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error(\"Failed to find cell for figure\", id, fig);\n",
       "        return;\n",
       "    }\n",
       "\n",
       "    var output_index = fig.cell_info[2]\n",
       "    var cell = fig.cell_info[0];\n",
       "\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
       "    var width = fig.canvas.width/mpl.ratio\n",
       "    fig.root.unbind('remove')\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable()\n",
       "    $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
       "    fig.close_ws(fig, msg);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.close_ws = function(fig, msg){\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width/mpl.ratio\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message(\"ack\", {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () { fig.push_to_output() }, 1000);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items){\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) { continue; };\n",
       "\n",
       "        var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
       "    var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
       "    button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
       "    button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
       "    buttongrp.append(button);\n",
       "    var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
       "    titlebar.prepend(buttongrp);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(el){\n",
       "    var fig = this\n",
       "    el.on(\"remove\", function(){\n",
       "\tfig.close_ws(fig, {});\n",
       "    });\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(el){\n",
       "    // this is important to make the div 'focusable\n",
       "    el.attr('tabindex', 0)\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    }\n",
       "    else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager)\n",
       "        manager = IPython.keyboard_manager;\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which == 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.find_output_cell = function(html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i=0; i<ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code'){\n",
       "            for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] == html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "}\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel != null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1100\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "img1 = combine_observations_multichannel(preprocessed_observations)\n",
    "img2 = combine_observations_singlechannel(preprocessed_observations)\n",
    "\n",
    "plt.figure(figsize=(11, 7))\n",
    "plt.subplot(121)\n",
    "plt.title(\"Multichannel state\")\n",
    "plt.imshow(img1, interpolation=\"nearest\")\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(122)\n",
    "plt.title(\"Singlechannel state\")\n",
    "plt.imshow(img2, interpolation=\"nearest\", cmap=\"gray\")\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise solutions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. to 7."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See Appendix A."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. BipedalWalker-v3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exercise: _Use policy gradients to tackle OpenAI gym's \"BipedalWalker-v3\"._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "env = gym.make(\"BipedalWalker-v3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = env.render(mode=\"rgb_array\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(img)\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2.74730590e-03, -2.55714543e-05,  1.98919237e-03, -1.59998274e-02,\n",
       "        9.17998850e-02, -2.62505747e-03,  8.60360265e-01,  3.34233418e-03,\n",
       "        1.00000000e+00,  3.22045349e-02, -2.62487587e-03,  8.53911370e-01,\n",
       "        1.85646505e-03,  1.00000000e+00,  4.40814108e-01,  4.45820212e-01,\n",
       "        4.61422890e-01,  4.89550292e-01,  5.34102917e-01,  6.02461159e-01,\n",
       "        7.09149063e-01,  8.85932028e-01,  1.00000000e+00,  1.00000000e+00])"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can find the meaning of each of these 24 numbers in the [documentation](https://github.com/openai/gym/wiki/BipedalWalker-v3)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Box(4,)"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1., -1., -1., -1.], dtype=float32)"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space.low"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1., 1.], dtype=float32)"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space.high"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a 4D continuous action space controling each leg's hip torque and knee torque (from -1 to 1). To deal with a continuous action space, one method is to discretize it. For example, let's limit the possible torque values to these 3 values: -1.0, 0.0, and 1.0. This means that we are left with $3^4=81$ possible actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(81, 4)"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "possible_torques = np.array([-1.0, 0.0, 1.0])\n",
    "possible_actions = np.array(list(product(possible_torques, possible_torques, possible_torques, possible_torques)))\n",
    "possible_actions.shape"
   ]
  },
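   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "Since `product()` enumerates the torques in lexicographic order, each action index can be read as a 4-digit base-3 number (digit 0 maps to -1.0, digit 1 to 0.0, digit 2 to 1.0). As a quick sanity check (not in the original notebook), index 40 = 1*27 + 1*9 + 1*3 + 1 should therefore map to the all-zero torque action:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Index 40 has base-3 digits (1, 1, 1, 1), i.e. all four torques equal to 0.0\n",
     "possible_actions[40]"
    ]
   },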
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "# 1. Specify the network architecture\n",
    "n_inputs = env.observation_space.shape[0]  # == 24\n",
    "n_hidden = 10\n",
    "n_outputs = len(possible_actions) # == 625\n",
    "initializer = tf.variance_scaling_initializer()\n",
    "\n",
    "# 2. Build the neural network\n",
    "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n",
    "\n",
    "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.selu,\n",
    "                         kernel_initializer=initializer)\n",
    "logits = tf.layers.dense(hidden, n_outputs,\n",
    "                         kernel_initializer=initializer)\n",
    "outputs = tf.nn.softmax(logits)\n",
    "\n",
    "# 3. Select a random action based on the estimated probabilities\n",
    "action_index = tf.squeeze(tf.multinomial(logits, num_samples=1), axis=-1)\n",
    "\n",
    "# 4. Training\n",
    "learning_rate = 0.01\n",
    "\n",
    "y = tf.one_hot(action_index, depth=len(possible_actions))\n",
    "cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
    "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n",
    "gradients = [grad for grad, variable in grads_and_vars]\n",
    "gradient_placeholders = []\n",
    "grads_and_vars_feed = []\n",
    "for grad, variable in grads_and_vars:\n",
    "    gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n",
    "    gradient_placeholders.append(gradient_placeholder)\n",
    "    grads_and_vars_feed.append((gradient_placeholder, variable))\n",
    "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try running this policy network, although it is not trained yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_bipedal_walker(model_path=None, n_max_steps = 1000):\n",
    "    env = gym.make(\"BipedalWalker-v3\")\n",
    "    frames = []\n",
    "    with tf.Session() as sess:\n",
    "        if model_path is None:\n",
    "            init.run()\n",
    "        else:\n",
    "            saver.restore(sess, model_path)\n",
    "        obs = env.reset()\n",
    "        for step in range(n_max_steps):\n",
    "            img = env.render(mode=\"rgb_array\")\n",
    "            frames.append(img)\n",
    "            action_index_val = action_index.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n",
    "            action = possible_actions[action_index_val]\n",
    "            obs, reward, done, info = env.step(action[0])\n",
    "            if done:\n",
    "                break\n",
    "    env.close()\n",
    "    return frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = run_bipedal_walker()\n",
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nope, it really can't walk. So let's train it!"
   ]
  },
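   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The training loop below uses the `discount_and_normalize_rewards()` helper defined earlier in this chapter for the CartPole example. As a reminder, here is a minimal sketch of it (assuming the same signature as the call below):"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "def discount_rewards(rewards, discount_rate):\n",
     "    # Accumulate rewards backwards: each step gets its own reward plus\n",
     "    # the discounted sum of all the rewards that follow it.\n",
     "    discounted_rewards = np.empty(len(rewards))\n",
     "    cumulative_rewards = 0\n",
     "    for step in reversed(range(len(rewards))):\n",
     "        cumulative_rewards = rewards[step] + cumulative_rewards * discount_rate\n",
     "        discounted_rewards[step] = cumulative_rewards\n",
     "    return discounted_rewards\n",
     "\n",
     "def discount_and_normalize_rewards(all_rewards, discount_rate):\n",
     "    # Discount each episode's rewards, then normalize across all episodes\n",
     "    # so the resulting action scores have zero mean and unit std deviation.\n",
     "    all_discounted_rewards = [discount_rewards(rewards, discount_rate)\n",
     "                              for rewards in all_rewards]\n",
     "    flat_rewards = np.concatenate(all_discounted_rewards)\n",
     "    reward_mean = flat_rewards.mean()\n",
     "    reward_std = flat_rewards.std()\n",
     "    return [(discounted_rewards - reward_mean) / reward_std\n",
     "            for discounted_rewards in all_discounted_rewards]"
    ]
   },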
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1000/1000"
     ]
    }
   ],
   "source": [
    "n_games_per_update = 10\n",
    "n_max_steps = 1000\n",
    "n_iterations = 1000\n",
    "save_iterations = 10\n",
    "discount_rate = 0.95\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    init.run()\n",
    "    for iteration in range(n_iterations):\n",
    "        print(\"\\rIteration: {}/{}\".format(iteration + 1, n_iterations), end=\"\")\n",
    "        all_rewards = []\n",
    "        all_gradients = []\n",
    "        for game in range(n_games_per_update):\n",
    "            current_rewards = []\n",
    "            current_gradients = []\n",
    "            obs = env.reset()\n",
    "            for step in range(n_max_steps):\n",
    "                action_index_val, gradients_val = sess.run([action_index, gradients],\n",
    "                                                           feed_dict={X: obs.reshape(1, n_inputs)})\n",
    "                action = possible_actions[action_index_val]\n",
    "                obs, reward, done, info = env.step(action[0])\n",
    "                current_rewards.append(reward)\n",
    "                current_gradients.append(gradients_val)\n",
    "                if done:\n",
    "                    break\n",
    "            all_rewards.append(current_rewards)\n",
    "            all_gradients.append(current_gradients)\n",
    "\n",
    "        all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n",
    "        feed_dict = {}\n",
    "        for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n",
    "            mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n",
    "                                      for game_index, rewards in enumerate(all_rewards)\n",
    "                                          for step, reward in enumerate(rewards)], axis=0)\n",
    "            feed_dict[gradient_placeholder] = mean_gradients\n",
    "        sess.run(training_op, feed_dict=feed_dict)\n",
    "        if iteration % save_iterations == 0:\n",
    "            saver.save(sess, \"./my_bipedal_walker_pg.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = run_bipedal_walker(\"./my_bipedal_walker_pg.ckpt\")\n",
    "plot_animation(frames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not the best walker, but at least it stays up and makes (slow) progress to the right.\n",
    "A better solution for this problem is to use an actor-critic algorithm, as it does not require discretizing the action space, and it converges much faster. Check out this nice [blog post](https://towardsdatascience.com/reinforcement-learning-w-keras-openai-actor-critic-models-f084612cfd69) by Yash Patel for more details."
   ]
  },
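  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a flavor of what an actor-critic update looks like, here is a minimal, self-contained sketch: a hypothetical linear actor and critic on a 1-D continuous action, with a Gaussian policy of fixed unit variance. It only illustrates the update rule (TD error as the advantage estimate), not the approach used in this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def actor_critic_update(theta, w, s, a, r, s_next,\n",
    "                        alpha=0.01, beta=0.01, gamma=0.95):\n",
    "    # Critic: linear value function V(s) = w @ s, updated by TD(0).\n",
    "    v, v_next = w @ s, w @ s_next\n",
    "    td_error = r + gamma * v_next - v  # also serves as the advantage estimate\n",
    "    w = w + beta * td_error * s\n",
    "    # Actor: Gaussian policy with mean theta @ s and unit variance, so\n",
    "    # d log pi(a|s) / d theta = (a - theta @ s) * s.\n",
    "    grad_log_pi = (a - theta @ s) * s\n",
    "    theta = theta + alpha * td_error * grad_log_pi\n",
    "    return theta, w"
   ]
  },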
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Pong DQN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's explore the `Pong-v0` OpenAI Gym environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gq8yjOdZx9yS"
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "env = gym.make('Pong-v0')\n",
    "obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(210, 160, 3)"
      ]
     },
     "execution_count": 95,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Discrete(6)"
      ]
     },
     "execution_count": 96,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each observation is a 210x160 RGB image. The action space is a `Discrete(6)` space with 6 possible actions: actions 0 and 1 do nothing, actions 2 and 4 move the paddle up, and actions 3 and 5 move the paddle down. The paddle is free to move immediately, but the ball does not appear until after 18 steps into the episode.\n",
    "\n",
    "Let's play a game with a completely random policy and plot the resulting animation."
   ]
  },
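  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, that action mapping can be written down as a small lookup table (the effects below come from the description above; `env.unwrapped.get_action_meanings()` reports the official ALE action names if you want to double-check):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Effect of each Pong-v0 action index on the agent's paddle:\n",
    "pong_action_effects = {\n",
    "    0: \"no-op\",\n",
    "    1: \"no-op\",\n",
    "    2: \"up\",\n",
    "    3: \"down\",\n",
    "    4: \"up\",\n",
    "    5: \"down\",\n",
    "}\n",
    "for index, effect in sorted(pong_action_effects.items()):\n",
    "    print(index, effect)"
   ]
  },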
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A helper function to run one episode of Pong. Its first argument must be a\n",
    "# policy function that takes the current observation and the step index and\n",
    "# returns the action for the agent to take.\n",
    "\n",
    "def run_episode(policy, n_max_steps=1000):\n",
    "    obs = env.reset()\n",
    "    frames = []\n",
    "    for i in range(n_max_steps):\n",
    "        obs, reward, done, info = env.step(policy(obs, i))\n",
    "        frames.append(env.render(mode='rgb_array'))\n",
    "        if done:\n",
    "            break\n",
    "    return plot_animation(frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAAEACAYAAAAp2kPsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAA8BJREFUeJzt3T1uE1EYQFEGZQMwW6FJ+kggFoPE\nQhAsBlGkTxq2MmxhKCIhEYwtx9f2xHNOF/lHr7n65jlv7GGe51fA4V6fewFwKcQEETFBREwQERNE\nxAQRMUFETBARE0TEBJGrcy9gk2EYtp5x+vr+zamWAn98+vFr2Pb4ImM6Riy3N9d7Pf/u/uGg1296\nDx79/Pxx79e8+/L9CCtpucyDiJggIiaILHLPdAy79i+H7qme8x482rQfes6+6txMJoiICSJigshq\n9kz2MxybyQQRMUFETBBZzZ7pKefmqJlMEBETRMQEETFBZLUfQOz6J259MJb/e4mHWjcxmSAiJoiI\nCSLDEn/s7NuHt8tbFKu369uJTCaIiAkii7zMm6ZpeYti9cZxdJkHpyAmiIgJImKCiJggIiaIiAki\nYoKImCCyyBMQDrpyqKc3HBa/POigK5yImCAiJoiICSJigoiYILLa783jshUfhe/LZIKImCAiJoiI\nCSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAki\nYoKImCAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKC\niJggIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJigoiY\nICImiIgJImKCiJggIiaIiAkiYoKImCByde4FHMvtzfVff9/dP5xpJayFyQQRMUFETBARE0TEBBEx\nQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0Qu9uZANwNyaiYTRMQEETFBREwQERNExAQRMUFE\nTBARE0TEBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQ\nERNExAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBET\nRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQERNExASRYZ7nc6/hH9M0\nLW9RrN44jsO2x00miIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKCiJggssib\nA+ElMpkgIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJi\ngoiYICImiIgJImKCiJggIiaI/AbgAEzkApWdfAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "run_episode(lambda obs, i: np.random.randint(0, 6))  # pick any of the 6 actions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The random policy does not fare very well, so let's use a DQN and see if we can do better.\n",
    "\n",
    "First, let's write a preprocessing function to scale down the input state. Since a single observation does not tell us about the ball's velocity, we will also need to combine multiple observations into a single state. The preprocessing algorithm below is two-fold:\n",
    "\n",
    "1. Convert the observation image to black and white and scale it down to 80x80 pixels.\n",
    "\n",
    "2. Combine 3 consecutive observations into a single state, which captures the velocity of the paddles and the ball."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "green_paddle_color = (92, 186, 92)\n",
    "red_paddle_color = (213, 130, 74)\n",
    "background_color = (144, 72, 17)\n",
    "ball_color = (236, 236, 236)\n",
    "\n",
    "def preprocess_observation(obs):\n",
    "    # Crop the playing field, downsample by 2, and flatten to a list of pixels.\n",
    "    img = obs[34:194:2, ::2].reshape(-1, 3)\n",
    "    tmp = np.zeros(80 * 80, dtype=np.float32)\n",
    "    for i, c in enumerate(img):\n",
    "        # Paddles and ball become white (1.0); everything else stays black (0.0).\n",
    "        if tuple(c) in {green_paddle_color, red_paddle_color, ball_color}:\n",
    "            tmp[i] = 1.0\n",
    "    return tmp.reshape(80, 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAApIAAAGgCAYAAAAZ0oddAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xe4ZVV5P/Dvq1hQmjIaARV/GitG\njbFB1GhsEUWMLSqKNbYYjWjsvfcWu6jYUGMvaBS7CGJvCBpRCAJShiKgIMr6/bH2wcPl3jvDYuDO\nMJ/P89yHe3ZZZ+199uV8z7vWPlOttQAAwHl1sZXuAAAAGyZBEgCAIYIkAABDBEkAAIYIkgAADBEk\nAQAYIkgOqqqnV9We63rbtWirVdVfn4/9H1xV+62LvlyYquotVfWsC6jt61XVd6uqLoj2WVxVfbuq\ndljpfgAwTpDM2eHqJ1X1+6r6bVW9uaq2Wm6f1tqLW2sPX5v2z8u251dV3XV6gz6tqlZX1fur6soX\nxnOvK4uF3dbao1prL7iAnvIFSV7Zpi9VrarHTsHyjKraa5H+Xaaq3lRVx1fVyVX19bl1VVUvm879\n6un38x1Qq+oWVbVvVZ1QVcdV1Yerapu59betqq9M/Tlskf2vNq3/fVUdUlW3X+a59qqqP1bVqdPz\n7VtV11mwzTZV9faqOmra7lfTfteZe742rTu1qo6Zztkl5pp5ZZLnn99zA8DK2eiDZFU9McnLkvxn\nki2T3CLJ9kn2rapLLrHPJhdeD9deVd0ryd5JXptkVZIdkpyRZL+qutyF2I/18vwsZgpjt03yibnF\nRyV5YZJ3LrHb25JcPsl1p/8+YW7dI5LcPckNk9wgyS5JHrkOunq56Xmvln59npLkXXPrT5v6+59L\n7P+BJD9IsnWSZyT5SFVdYZnne3lrbbMk2yU5Msk7Ziuqausk+ye5TJJbJdk8yY2TfC3JHRa0s9XU\nzt8k2THJv82t+1SS21bVlZbpBwDrs9baRvuTZIskpya5z4LlmyU5LslDp8fPTfKRJO9L8rskD5+W\nvW9un92THJ5kdZJnJTksye3n9n/f9PvVkrQkD0ryf0mOT/KMuXZuluSAJCclOTrJG5Jccm59S/LX\nixxLTc//5AXLL5bkp0mePz1+cJJvTu2enOSQJLeb2/7BSX6VHlR+nWS3uXUPTXJwkhOTfD7J9gv6\n9W9J/nfa783pVb75vnwyyR7T709Ncuj0PD9L8s/T8usmOT3Jn6fX5qRp+V5JXjjX1r8m+WWSE9ID\nybYL+vKoqS8nJXljklriGtg9yReXWPfCJHstWHad6RrYYol99k/yiLnHD0vyrSW2fUqSA5NsMj1+\ndJKDklx6La7dGyc5ZZHlt09y2IJl10r/QLH53LJvJHnUEm0vPNc7JzltwXn5UZKLLdO/q02vwyZz\ny16e5G0Ltts3yYMuzL97P378+PGz7n429orkTkkuneRj8wtba6cm+WzOWV3ZNT1MbpXk/fPbV9X1\nkrwpyW5JtkmvbG63hue+ZZJrJ7ldkmdX1XWn5X9Or3CtSq/g3C7JY9biWK6d5KpJPrzgWM5K8tEF\nx3Lz9BC3Kslzknysqi5fVZdN8vokd26tbZ5+fn44HeOuSZ6e5B5JrpAeRD6woA93n9q+3rTuX2bD\nulNF9I5JPjhte2h6NWvLJM9L8r6q2qa1dnB6CDygtbZZa+1cUwyq6h+TvCTJfdLP9+Fz7c7cNclN\n06uC90lypyXO298k+fkS6xZzs+n5njcNbf+kqu45t36H9JA186Np2WJekR7wnllV10zy4iQPaK2d\nvhb9uHV66FwbOyT5VWvtlLXs19mma+J+6aF95vZJPj5dW2ulqrZNfw2+tWDVwenVWwA2QBt7kFyV\n5PjW2p8WWXf0tH7mgNbaJ1prZ7XW/r
Bg23sl+XRrbb/W2h+TPDu9GrOc57XW/tBa+1H6m/oNk6S1\n9r3W2rdaa39qrR2W5K1J/mEtj2XW7zUdy7FJXttaO7O19qH0IHWXad1ZSa5fVZu21o5urc3CyqOS\nvKS1dvB0vl6c5EZVtf1cuy9prZ0wnZ9vpJ+DW03r7pV+Do+ajvPDrbWjpvP5ofTq4c3W4jiTHtjf\n2Vr7fmvtjCRPS7JjVV1tbpuXttZOaq39X5KvJLnREm1tlV4VXVtXTnL99Grutkkem+Tdcx8ENpvW\nzZycZLPF5klOQWz3JI9Lr6q+vLX2gzV1oKpukH6NLTWMvdDCPs36tfky+zypqk5KPze3TPLAuXWr\nkvx2rj93q6qTquqUqvrCgnaOn9o5Mn34/SML1p+S/hoAsAHa2IPk8UlWLTGnb5tp/cwRy7Sz7fz6\n1trv04e4l/Pbud9/n/5mn6q6VlV9Zrrp53fpgW3VYg0sMOvrNousW3gsR7bW5oPu4elDw6cl+Zf0\n0Hh0Ve0zd5PF9kleNwWGk9KHlCvnrLzOn4OWXiW837To/pmr5FbV7lX1w7n2rr+Wx5n083343HOd\nmn6+5/uy6PldxIlZPlAt9IckZ6YP/f6xtfa19KB6x2n9qelTJma2SHLqgvN9tunDwlfSh4LfuKYn\nr37H/ueSPL619o217PPCPs36tVyAfuVUDb5a+jFfe27d6sxdZ621T03bPiHJwnnFq6Z1l0mfUvH5\nBes3T59+AMAGaGMPkgekDy3eY35hVW2W5M5JvjS3eLkK49HplarZ/pum39Qw4s3p8xav2VrbIn04\neW3u+v15kt8kuff8wqq6WJJ75pzHst2CCtlV028wSWvt8621O6QHhUOSvH3a5ogkj2ytbTX3s2lr\nbf+5dhaeow8kuddUtbx5+hB7psdvT6/mbT0FjZ/OHeeaqrlHpQfb2TFeNv18H7mG/Rbz4/Q5hOdl\n+4Xm+3tQzjlUe8MsMwRdVXdJn8LwpfSh7iVN5+2LSV7QWnvv2nZ4ev6rV9V8YF62XzNTRffx6R8i\nNp0WfynJ3adra61MVeq9ktyiquY/MFw355wKAMAGZKMOkq21k9Pn5/1XVf1TVV1iGh797/RQtrZv\n1h9JsktV7TTd6f3crF34W8zm6TdznDpVAx+9NjtNFa8npc+3u39VXXq6G3bP9OrTa+Y2v2KSx03H\ne+/0N/PPVtVfVdWuUzA7I72SNZsH95YkT6vpe/+qastp3+X69IP0SuieST7fWptVni6bHr6Om9p6\nSHpFcuaYJFde6q759ID6kKq6UVVdKr1qe+BU3Tuv9k1y46q69GxBVW0yPb54kotP53JWtf56+k1S\nT5u2+/v0u75nlbb3JNmjqrab5gU+MT1AncsUqPZMv3nrQenX0M5LbLtdki8neUNr7S2LrL/Y1OdL\n9Id16dn5a639In2u63Om5f+cPnf0o2tzglpr+6aH90dMi16dfhf5e6vqGtVtnqWnD2R6nR6YXile\nPS27dJK/S38NANgAbdRBMklaay9Pr/q9Mj3AHZhefbvdNP9ubdo4KMm/pw/lHp0ewI5ND2Pn1ZPS\nh4FPSa/afWhtd5zmGj4wfYhxdfrd0Jsm+fvW2vxQ+4FJrpke8l6U5F7T+osl2SM9NJyQPjfz0VPb\nH0//mqQPTkPuP02v2q7J3uk3Z+w918+fJXlVekX4mPQbXr45t8+X06tlv62q+SH52f5fTL8z/qPp\n5/saSe67Fn05l9baMdPz7Tq3+Jnpw7lPTfKA6fdnTtufOW27c/o8w7cn2b21dsi071uTfDrJT9LP\n0T7TssW8LcknW2ufnc7/w5LsOX29zkIPT3L1JM+d+27GU+fW33rq52fTK8x/SDI/X/G+SW6SPpT/\n0vTX/LhlTs1Cr0jy5Kq6VGvt+PSvyTo9yX7p1+oP0z8ELfzgc9LUz2PSK693mxvm3yXJV2fzZgHY\n8N
QSU7c4H6ah8ZPSh6d/vdL9YXnTXffvTnKzpeYysu5V1YFJHtZa++lK9wWAMYLkOlJVu6TPHav0\natvNk9xYMAEALqo2+qHtdWjX9CHho9KHje8rRAIAF2UqkgAADFGRBABgiCAJAMCQxf5FlxVXVUPj\n7a/9p8ut664AFyGP/9wJo9/vCsAi1ssguTEHwjvstOM6b3Pf/Q9Y522y8r67x13WvNF5dJNX77PO\n2wTgosvQNgAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGDIevn1Pyxuua/xuSC+NogN13Jf43NB\nfG0QABsnFUkAAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAM\nESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIArDRqqrbVNVv\nzmcbe1XVC9dVny4sVXVQVd3mAmr7kVX12guibcadl2u1qr5dVTusabtNzn+3uLDcYacdV7oLbCC+\nu8ddVroLrICqOizJXyX5c5LTknwuyWNba6euZL82ZFVVSZ6U5BFJrpzkuCTvT/Lc1toZK9m386Kq\n9krym9baM2fLWmtrDAmDz3XJJM9Mcou5Zf+Y5JVJ/jrJ8Ule2lp729z6+yd5SZJVSfZN8tDW2gnr\noC83SvJfSW6Q5JQkb22tvWBu/e2SvDHJVZMcmOTBrbXDl2nvvkmekOT66X9jv07y7iRvbq2189vf\n9cwrkzw/yT2X20hFEuCiZZfW2mZJbpzkJulv6OdQ3Tr7//+6bm898/r0ELl7ks2T3DnJ7ZL894XZ\niarakAo/uyY5pLV2ZJJU1SWSfDzJW5NsmeRfkry6qm44rd9hWvfA9A9Cv0/ypnXUl72TfD3J5ZP8\nQ5LHVNXdpuddleRjSZ41rf9ukg8t1VBVPTHJ65K8IsmVpr4+KsnfJ7nkEvtcfB0dx0r4VJLbVtWV\nltvoovqHD7BRm97EP5deOUlVfbWqXlRV30x/o756VW1ZVe+oqqOr6siqeuHsja+qHlxV36yqN1TV\nyVV1yFS9yTLtbVtVn6qqE6rql1X1r3PbX7yqnl5Vh1bVKVX1vaq6yrTuOlW177Tfz6vqPnP77VxV\nP5v2ObKqnjQtX1VVn6mqk6b9vjELs1M/PlpVx1XVr6vqcXPtbToN751YVT9LctOlzmFVXTPJY5Ls\n1lo7oLX2p9baQekVmn+aqmwzq6ZjOKWqvlZV209tVFW9pqqOrarfVdVPqmr2mlyqql5ZVf9XVcdU\n1VuqatNp3W2q6jdV9ZSq+m2Sd1XVwVV117n+bTId442nxx+uqt9Or9fXZ8OSVfWIJLsleXJVnVpV\nn56WH1ZVt5/ry2ur6qjp57VVdakFfXnidBxHV9VDlrn87pzka3OPL59kiyTvbd13khyc5HrT+t2S\nfLq19vWpev6sJPeoqs0XeU12qqrj566dG06v5XWW6MvVkry/tfbn1tqhSfZLMqvE3iPJQa21D7fW\nTk/y3CQ3XKytqtoyvTr3mNbaR1prp0zH8oPW2m6z6vR0bb25qj5bVaelB7G7VNUPptf/iKp67ly7\n+1TVvy94rh9X1T+v4drZtKpeVVWHT6/3fnPXzqLXwWKq6q5V9cPp72j/qrrBbN10Tr6X5E5L7Z8I\nkgAXSdMb7c5JfjC3+IHp1bXNkxyeZK8kf0ofbvzbJHdM8vC57W+e5ND04cbnJPlYVV1+mfY+mOQ3\nSbZNcq8kL54LW3skud/Upy2SPDTJ76vqsulDmXsnuWKS+yZ5U1XNQsY7kjyytbZ5eij+8rT8idNz\nXSG9MvT0JG0Kk59O8qMk26VXD/+jqmZvhs9Jco3p505JHrTMabxd+nDwt+cXttaOSPKtJHeYW7xb\nkhdM5+qH6cPfST+nt05yrfRq3H2SrJ7WvXRafqP012C7JM+ea/NK
6SFs+/Tz/IHpHM7cKcnxrbXv\nT48/l+Sa6efx+7M+TEPI70/y8tbaZq21XRY51mekD0XfKMkNk9ws56xmX2nq/3ZJHpbkjVV1uUXa\nSZK/SfLz2YPW2jFT3x9S/QPFjtMx7TdtskP66zXb/tAkf5zOzTm01vZPr16+ewpO70vyrNbaIUv0\n5bVJdq+qS1TVtZPsmOSLSzzvaenX+2LBa8ckl0ryySWeZ979k7wo/e9iv/Qh8N2TbJXkLkkeXVV3\nn7Z9d5IHzHasXqXdLsk+Wf7aeWWSv0uyU/o18uQkZ03rFr0OFqqqv03yziSPTLJ1+nn91OwDxOTg\n9OthSYIkwEXLJ6rqpPQ3sK8lefHcur1aawe11v6U/uazc5L/aK2d1lo7Nslr0oPczLFJXttaO7O1\n9qH0cHCXJdq7UvoQ31Naa6e31n6YZM/0N9CkB9RnttZ+PlVyftRaW53krkkOa629a6r4/SDJR5Pc\ne9rvzCTXq6otWmsnzoWmM5Nsk2T7qX/fmOao3TTJFVprz2+t/bG19qskb587rvskeVFr7YQpEL5+\nmXO5KsnRS6w7elo/s89UUTsjPZTtOIX5M9MDxXWSVGvt4Nba0VVV6eHwCVNfTkl/rebP/1lJntNa\nO6O19of0sH23qrrMtP7+6QEtSdJae+dUKTsjf6mubbnM8c3bLcnzW2vHttaOS/K89A8KM2dO689s\nrX02yalJrr1EW1ulz0ec94H0kHxGkm8kecZ0/pNksyQnL9j+5PTztpjnpgerbyc5Mn2O41I+k/6h\n5g9JDknyjqkiel6fd1V6aP/TbMFUwTupqv5QVbee2/aTrbVvttbOmv4Wvtpa+8n0+Mfp5+Ifpm0/\nleRa1avfST/nH2qt/TFLXzsXS/8g9vjW2pFTtXX/WVX0PFwHj0ifM3rg1Ma701+fW8xtc0r667kk\nQRLgouXurbWtWmvbt9YeMwWQmSPmft8+ySWSHD29GZ6UXpG44tw2Ry64geDw9GrjYu1tm2QWiOa3\n3276/Srp1Z6Ftk9y81kfpn7slh5Mkz6MvHOSw6sPGc/uOnxFkl8m+UJV/aqqnjrX3rYL2nt6etVy\n1s/5fi95Y0X6TSHbLLFum2n9zNltTsOzJyTZtrX25SRvSA87x1bV26pqi/RK6mWSfG+un/8zLZ85\nbhpenLX7y/QK0S5TmLxbericTR14afWpA79Lcti023zYXc62Oee5WPhar54PUenTGTZboq0TMxfG\npqHiD6Z/qLhkesXvyVU1+1ByanqVet4WOXcYTZK01s5Mr6ZfP8mrlrrJZaqe/0/6kPSl06/BO1XV\nYwaed3X69IWz56q21nZqrW01rZvPU/PXV6rq5lX1lerTEE5On1e5amrj9PR5mQ+YAuL9krx3WrfU\ntbNqOp5z/T2dx+tg+yRPXPC3cpWc83XfPMlJi+x7NkESYOMx/4Z7RHr1YdUUPLdqrW3Rznkn73ZT\n5WzmqkmOWqK9o5JcfsG8tqumV4xmz3eNRfp0RJKvzfVhq2n49dFJ0lr7Tmtt1/SA+4lMN7lMFZcn\nttaunh6o9qg+h/OIJL9e0N7mrbWdp+c7Ov3Ncr6PS/lykqtU1c3mF06Vxlsk+dLc4qvMrd8sveJ7\n1NTX17fW/i59TuC1kvxnegj9Q5Id5vq5Zes3Ss0sFpBmw9u7JvnZFC6TXp3cNcnt06t1V5t1Z5m2\n5h2VHixmFr7W58WPc85h6esn+UVr7fNTVe7n6UO3d57WH5S54dOqunr6MPIvFmu8qrZLn6LwriSv\nWjAUO+/qSf7cWnvPVO3+TXqgnV0LC5/3sunX6EGLtHVA+t/Lrkse9V8sPNd7p1cer9Ja2zLJW/KX\n1yXpw9u7pU+l+H1r7YCzG1r62jk9i/89rek6mHdEenV+/m/lMq21D8xtc93MDf8vZkO6C2yjsO/+\nB6x5I0hyk1fvs9JdYAM2DZF9
If2N+Fnp1Zn/l+TKrbXZjRJXTPK4qnpTkrunv6l8don2jqiq/ZO8\npPoNMddKn0u327TJnkleUP0Gl1+mz6M7Mn3o8aVV9cD0N/mkz9M7Nb3icu8kn2mtnTxVWM5K+k0C\n6UOVh6YPR/55WvftJKdU1VPSh63/OPV702lI87+TPK2qDkxy2STnuNFhwTH9oqrekuT9U/++kz7M\n+K4kX2ytfXFu852r6pbT878gybemc3LT9KLN99Pnyp2e5KzW2llV9fYkr6mqx7bWjp0C0vVba59f\nqk/TOXpRelDde2755ulBZ3V6pfPFC/Y7Jj1YLeUDSZ5ZVd9JD0LPTp9/OOKz6VW3F02Pf5DkmtN8\n2a9M/bhrkpdP69+f5ICqulX6eXp+ko8tqG4nOfvrmPZKnzv71PSK4wvS5wgu9Itpl/unn7crpt8x\n/pVp/ceTvKKq7pkebJ+d5MdtkfmWrbWTqup56fN3K8nn01/PG6RfR8vZPL1af/r0oeT+Sb4w1/YB\nVXVWkldlqkZOx7rctfPO9DvfH5j+2t5s2m5N18G8tyf5eFV9Mf26vUyS2yT5emvtlKq6dPo8zOXm\nEatIAmzEZkONP0sfjvxIzjmUe2D6pP3j00PBvVqf17iU+6VXQI5Kf5N+zlzYenV6iPtCkt+lB4FN\np7Bwx/S5gUcl+W2Sl6VXpJI+Z+ywKUQ+Kn8JptdMv2ni1PRq0Ztaa19prf05PaTcKP07/o5PD7Gz\nOWLPSx+2/fXUl7PfuJfw2Gn/903P9T9Jvppzf7fe3ulVshPS33xnN1Bskf6GfeL0vKvTh+WT5Cnp\nofpb0/F9MUvPO0zSPwBMx7tTzvlVNe+Z2j8y/fX81oJd35E+1/SkqvrEIk2/MP3rb36c5CfpoWT0\nS9Y/neQ6VbXt1OdD0+f0vT79tf9a+jzYPaf1B6W/tu9Pn5e7efrd8ot5XHogfNY0pP2Q9Jt4brVw\nw9ba79LvzH5C+vn/YZKfzo5rmgt6z/Rr+8T0m8vuu7CdufZenn7T2JPTw9sx6dNBnpJk/2XOx2OS\nPL+qTkkPq4t9ddR70j9czYf35a6dJ6W/Tt9Jv+Zelp7p1nQdzB/Pd5P8a/rw+Ynp1+KD5zbZJclX\nW2vLVqZriakFK+p1d778+tcpYIP3+M+dsNjwDouoqgcneXhr7ZYr3Rc2PNW/cuh6rbX/WOm+bAiq\navckj1if/t6mqv3DWms/XW47Q9sAwDrV5v7VGpY33Tj1mKy7L2FfJ1prN1+b7QxtAwCsgOrfb3pc\n+jD53mvYfL20Xg5tr169ev3rFLDB23rrrQ1tA6xDKpIAAAwxRxLgIqKqjOYAF4jW2qIjOiqSAAAM\nESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCI\nIAkAwBBBEgCAIdVaW+k+nMvr7nz59a9TwAbv8Z87oVa6DxekqvL/TuAC0Vpb9P+fKpIAAAwRJAEA\nGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDA\nEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAAQwRJAACG\nCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEkAQAYIkgCADBE\nkAQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGC\nJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEk\nAQAYIkgCADBEkAQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiC
AJ\nAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkA\nAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIA\nMGSTle4AAADn1lpbcl1VXYg9WZqKJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBB\nEgCAIYIkAABDBEkAAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDAEEESAIAhm6x0B4CLtu/ucZcl\n193k1ftciD0BYF1TkQQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABg\niCAJAMAQ/0QiAMB6qKpWugtrpCIJAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAE\nAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEkAQAYIkgCADBEkAQAYMgmK90B4KLtJq/eZ6W7\nAMAFREUSAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAAQ/wT\niQAMaa0tua6qLsSeXPQsPLfOJ+srFUkAAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDAEF//AwDr\nGV/3w4ZCRRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCI75FMcoeddlxy3b77H3Ah\n9gQAYMOhIgkAwBAVSS4QD3jf/57j8fsecM0V6gkAcEFRkQQAYIggCQDAEEESAIAhgiQAAEMESQAA\nhgiSAAAM8fU/XCB83Q8AXPSpSAIAMESQBABgiCAJAMAQcyQBGFJVK90FYIWpSAIAMESQBABgiCAJ\nAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkA\nAIZsstIdWB/su/8BK90FAIANjookAABDBEkAAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDAEEES\nAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAAQwRJAACGCJIA\nAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEkAQAYIkgCADBEkAQA\nYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAA\nQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEkAQAY\nIkgCADBEkAQAYIggCQDAEEHfAWbJAAAB7ElEQVQSAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIA\nMESQBABgiCAJAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCA\nIZusdAcAWDdaa7XSfQA2LiqSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGCJAAAQwRJAACG\nCJIAAAwRJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDqrW20n04l9WrV69/nQI2eFtvvXWt\ndB8ALkpUJAEAGCJIAgAwRJAEAGCIIAkAwBBBEgCAIYIkAABDBEkAAIYIkgAADBEkAQAYIkgCADBE\nkAQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIAMESQBABgiCAJAMAQQRIAgCGC\nJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJAEAGBItdZWug8AAGyAVCQBABgiSAIAMESQBABgiCAJ\nAMAQQRIAgCGCJAAAQwRJAACGCJIAAAwRJAEAGCJIAgAwRJ
AEAGCIIAkAwBBBEgCAIYIkAABDBEkA\nAIYIkgAADBEkAQAYIkgCADBEkAQAYIggCQDAEEESAIAhgiQAAEMESQAAhgiSAAAMESQBABgiSAIA\nMOT/A0y5NVatsS4uAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 792x504 with 2 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "obs = env.reset()\n",
    "for _ in range(25):\n",
    "    obs, _, _, _ = env.step(0)\n",
    "\n",
    "plt.figure(figsize=(11, 7))\n",
    "plt.subplot(121)\n",
    "plt.title('Original Observation (160 x 210 RGB)')\n",
    "plt.imshow(obs)\n",
    "plt.axis('off')\n",
    "plt.subplot(122)\n",
    "plt.title('Preprocessed Observation (80 x 80 Grayscale)')\n",
    "plt.imshow(preprocess_observation(obs), interpolation='nearest', cmap='gray')\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_observations(preprocess_observations, dim_factor=0.75):\n",
    "    dimmed = [obs * (dim_factor ** idx)\n",
    "              for idx, obs in enumerate(reversed(preprocess_observations))]\n",
    "    return np.max(np.array(dimmed), axis=0)"
   ]
  },
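  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for `combine_observations`, here is a toy example on made-up 1x2 \"frames\": the most recent frame keeps full brightness, each older frame is dimmed by another factor of `dim_factor`, and the element-wise maximum merges them into one image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "frame_oldest = np.array([[1.0, 0.0]])\n",
    "frame_middle = np.array([[0.0, 1.0]])\n",
    "frame_newest = np.array([[0.0, 0.0]])\n",
    "\n",
    "# The oldest pixel is dimmed to 0.75 ** 2 = 0.5625, the middle one to 0.75.\n",
    "combine_observations([frame_oldest, frame_middle, frame_newest])"
   ]
  },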
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAF4CAYAAAB9xrNzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEWVJREFUeJzt3HuwdXdd3/HPN3kSSMgN5CKRCDao\nhUBbRQ04KM8gCgkqGbHaglDoKIKDNqPSVMZLBAS1ZXhgxJF6i5pELral1U6dGuEJMJJaKB0vWBHJ\nDUgIJAFCSTTAr3+s36krh3POcwk55xvzes2cyd77t87av7323u+19trnSY0xAkBfx+z1BADYmVAD\nNCfUAM0JNUBzQg3QnFADNCfUe6iqLqyqi3cY//Oq2n8X3O/+qvrgnVzHRVX1si/UnHbLXbVN706q\n6r9V1b/4AqznYVU1qmrfF2JebE+ot1BVz6iqd1XVp6rquvnCfvxuz2OMcdYY4+Bu328tXlRVf1VV\nt1bVNVX1iqq6127P5c7YameyV9t0t1XVi6vqyvka/mBVvWFjbIxxzhjjN/Z4fqdV1a9V1fVVdUtV\nva+q/s1qfFTVw49gfQer6nvvmtnuPaHepKp+OMmBJC9P8qAkX5rkF5M8bS/ntctek+R5SZ6d5OQk\n5yT5piRv3M1JOFI7OvNo+VlJnjTGOCnJ1yT5w72d1ed5VZKTkjwiyalJvj3J+/d0Rp2NMfzMnywv\nmE8l+ac7LHOvLCH/8Pw5kORec2x/kg8m+ddJbkhyXZLzkpyb5H1Jbkry4tW6LkzyO0nekOSWJP8r\nyT9ejV+V5c22sewbk/zmXPbPk3zNatnTk/yHJB9NcmWSH1qNnZDkoiQ3J3lvkhcl+eA2j+/Lk3w2\nyddtuv2MJH+T5Inz+kVJfinJH8z5XJ7koXOssrwRb0jyySR/muRRq+3375Jck+Qjcx0nbNp+FyS5\nPslvJfmLJN+6mse++Ri/el5/01z2E0neluSsefvzktye5G/nc/q7W2zTw3kuf2T1XD53NY9z57a8\nJcmHkvzoNtvzzCRvSXJjko8luSTJaavxC+bv35LkL5N80zbreWqS98zteW2SC3d4jf5CkgM7jB9M\n8r3z8nOSvGM+JzfP1845q2W/bG7XW5JcluS1SS6eYw9LMpLsW71/fnVuqw8leVmSY7eZw58lOW+b\nsbfN9f7f+dx9d5L7Jvm9+dzfPC8/ZC7/M1les7fN5X9h3v4Ps7w+b5rb9rv2ujFH3aa9nkCnnyRP\nSfKZjRfeNsu8JMkVSR6Y5AFJ/ijJS+fY/vn7P5nkuCTfN19Yl2Y5Mj0rya1Jvmwuf2GWmHznXP5H\n5xvluDl+Ve4Y6ttmII5N8ookV8yxY5K8e97v8Un+QZIPJHnyHP/ZJG9Pcr8swf2zbB/q5ye5epux\ny5O8Yl6+aL55vzFL8F6d5B1z7MlzPqdlifYjkjx4jr0qyX+Zczk5ye+u1rmx/X5urvOE+ZguWc3h\nqUn+YnX9X871bET3f6/GLkrysk2PYb1ND+e5fMl8bs5N8ukk953j1yX5hnn5vpk7ji222cOTfPOc\n3wOyROjAHPvKLNE9fV5/WJIzt1nP/iSPns/1P8qyk9sudN+TJU4vynI0feym8YO5Y6hvz/JaPTbJ\nC7LstGqOvzNLxI9P8vgsO4rtQv2fkrwuyX3mNv3jJN+/zRx/JcvBxnOTfPkW4yPJw1fXvyjJ05Oc\nOJ/vNyV581aPaV6/z9y2z82yc/+qLDvKR+51Z46qTXs9gU4/SZ6Z5PpDLPPXSc5dXX9ykqvm5f1Z\nQnzsvH7yfMGdvVr+3RtvsCzxvWI1dsymAFyVO4b6stWyj0xy67x8dpJrNs3zx5L8+rz8gSRPWY09\nL9uH+sfXc9o09vokvzwvX5Tk9auxk7
Ic1ZyR5IlZPkE8Nskxq2Uqy1HSmavbHpfkytX2+9sk916N\nPzzLDuHEef2SJD+5zfxOm9v71NUcdwr14TyX+1bjNyR57Lx8TZLvT3LKEb7GzkvyntVjuyHJkzJ3\nzkewngNJXnWI1/Jlc3vfmOSC1djB3DHU71+NnTi34RdnOe33mY1tP8cvzhahznKa8G8yPx3N8X+e\n5K3bzO+EJC/O8n64Pctpj/WR/B1CvcXv/5MkN2/1mOb1707y9k2/87okP3Uk27nLj3PUd3Rjkvsf\n4tzo6UmuXl2/et72/9cxxvjsvHzr/O9HVuO3Zonahms3LowxPpfl4/Z6fWvXry5/Osm951wfmuT0\nqvr4xk+WN8GDVnO+dvW76/lv9rEkD95m7MFzfKu5fyrLUdzpY4y3ZPn4/dokN1TVv6+qU7IcUZ6Y\n5N2ref7+vH3DR8cYt63W+/4spz++rapOzHIu89Ikqapjq+pnq+qvq+qTWSKcJPff4fGtHc5z+ZnV\n9U/n7567p2c5yr66qi6vqsdtdQdV9aCqen1VfWjO8eKN+c3Hdn6WnfANc7ktn/uqOruq3lpVH62q\nT2T55LPt4xxjXDLGeFKWndfzk7y0qp68zeLXr37v0/PiSVm2xU2r25I7vo7WHprlk8d1q+f2dVmO\nrLea361jjJePMR6T5Wj5jUneVFX322r5qjqxql5XVVfP7fi2JKdV1bE7zOfsTe+JZ2bZAd3tCPUd\nvTPLUcF5Oyzz4Swvgg1fOm87WmdsXKiqY5I85CjWd22Wo9LTVj8njzHOnePXre9nznk7b0lyRlV9\n3frGqjojyxHy+kup9dxPynI648NJMsZ4zXwTPjLJV2T5GP6xLDuqs1bzPHUsX3htGFvM6bezHJ09\nLcl7Z+CS5BnztidlOT/6sI3p7LCutaN+LscY/3OM8bQsIXpztv+i9eVzHo8eY5yS5bTExvwyxrh0\njPH4OY+R5bTPVi7NcsrojDHGqVnO7dc2y67nefsY401J/iTJow7nsa1cl+R+cwe54Yxtlr02y3vn\n/qvn9pQxxlmHMcdPZtlO98lyTnwrP5LlVNHZczt+47x9u+f62iSXb3pPnDTGeMGh5tORUK+MMT6R\n5Zzoa6vqvLkXP66qzqmqn5+L/XaSH6+qB1TV/efy2/4t9GF4TFV9xzwyPj/Li/2KI1zHHye5paou\nqKoT5pHmo6rqa+f4G5P8WFXdt6oekuQHt1vRGON9WSJwSVU9dq7rrCxfVF42xrhstfi5VfX4qjo+\nyUuznDK5tqq+dh4BHpflo/dtST43PzH8cpJXVdUDk6SqvmSHI70Nr0/yLVnOn166uv3kLNvrxixH\n6i/f9HsfyXK+fjtH9VxW1fFV9cyqOnWMcXuW87af22bxk7N8wfWJqvqSLDusjfV8ZVU9cf7Z421Z\ndmI7reemMcZtcyf6jB3m95yqempVnVxVx1TVOVm+H/kfh3psa2OMq5O8K8mF8zE/Lsm3bbPsdUn+\ne5JXVtUp837PrKonbDPHn5ivk+Or6t5J/lWSj2f50i/5/Ofu5Czb5+PzqPunNq1y8/K/l+QrqupZ\n8z183Ly/RxzJNuhCqDcZY7wyyQ9nOVf70Sx75hdmOWpKlm+y35XlCOVPs/ylxp35hx//Ocv5tJuz\n/EnVd8w3/5HM+bNJvjXLebsrsxy5/kqWo8wk+eksH+uvzPJm+q1DrPKF8/cvzhKZ389yDvDpm5a7\nNMsb5qYkj8lytJgkp2QJ8s3zfm9M8m/n2AVZzkdeMT/CXpblSGmnx3ddlk87X5/lL2Q2/OZc/4ey\n/AXG5h3cryZ55Pzo++Z8vjvzXD4ryVXzMTw/y8fqrfx0kq/O8lcp/zXJf1yN3SvLF70fy3L64YFZ\nvlvYyg8keUlV3ZJlh7LTn0p+Msupr2uyxO/nk7xgjPGOQz+sz/PMLN8j3Jhl27why85xK8/O8qXj\ne7
M897+T7U+jjSS/nuWxfzjLF65PnafQkuV00G/M5+67spyTP2Euf0WW1+Taq5N8Z1XdXFWvGWPc\nkmXn/s/m+q/P331Jfbez8c0uwCHNfzjzf8YYm49ouQs5oga2NU8XnDlPZTwly3cCW3064S7kX34B\nO/niLKdrvijLXyS9YIzxnr2d0j2PUx8AzTn1AdCcUAM0tyfnqKvK+RaATcYYW/4jJkfUAM0JNUBz\nQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0J\nNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3Q3L69ngDA\nbhhj7DheVbs0kyPniBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDm\nhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoT\naoDmhBqgOaEGaE6oAZrbt9cTANgNVbXXUzhqjqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZo\nzj94Ae4RDh48uOP4/v37d2UeR8MRNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3Q\nnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBz\nQg3QnFADNCfUAM0JNUBzQg3QnFADNFdjjN2/06rdv1OA5sYYtdXtjqgBmhNqgOaEGqA5oQZoTqgB\nmhNqgOaEGqC5fXs9AYDdcODAgR3Hzz///F2ayZFzRA3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0J\nNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfU\nAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzNcbY/Tut2v07BWhujFFb3e6IGqA5\noQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaE\nGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNq\ngOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgB\nmhNqgOaEGqA5oQZoTqgBmhNqgOb27fUE9sqBAwd2HD///PN3aSb3LJu3u+0Mh+aIGqA5oQZoTqgB\nmhNqgOaEGqA5oQZoTqgBmqsxxu7fadXu3+kmBw8e3HF8//79uzKPv282b1fbEQ7fGKO2ut0RNUBz\nQg3QnFADNCfUAM3dY/+nTE94whP2egp3e1t9EX355ZfvwUzg7zdH1ADNCTVAc0IN0Nw99h+8HOpx\nV235d+esbLUNbTc4ev7BC8DdlFADNCfUAM3dY/+OmjvP+WjYHY6oAZoTaoDmhBqgOaEGaE6oAZoT\naoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6o\nAZoTaoDmhBqguX17PYG9UlV7PQWAw+KIGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5\noQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaE\nGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNq\ngOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZo
TqgBmhNqgOaEGqA5oQZoTqgB\nmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZo\nTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqC5GmPs9RwA2IEjaoDmhBqg\nOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDm\nhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqguf8HSNWGQEW2Vh4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "n_observations_per_state = 3\n",
    "\n",
    "obs = env.reset()\n",
    "for _ in range(20):\n",
    "    obs, _, _, _ = env.step(0)\n",
    "\n",
    "preprocess_observations = []\n",
    "for _ in range(n_observations_per_state):\n",
    "    obs, _, _, _ = env.step(2)\n",
    "    preprocess_observations.append(preprocess_observation(obs))\n",
    "\n",
    "img = combine_observations(preprocess_observations)\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.title('Combined Observations as a Single State')\n",
    "plt.imshow(img, interpolation='nearest', cmap='gray')\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are going to build the DQN. Like the DQN for Ms. Pac-Man, this model will use 3 convolutional layers, followed by a fully connected hidden layer, and finally a fully connected output layer with 6 neurons, one for each possible action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "reset_graph()\n",
    "\n",
    "input_width = 80\n",
    "input_height = 80\n",
    "input_channels = 1\n",
    "\n",
    "conv_n_maps = [32, 64, 64]\n",
    "conv_kernel_sizes = [9, 5, 3]\n",
    "conv_kernel_strides = [4, 2, 1]\n",
    "conv_paddings = ['VALID'] * 3\n",
    "conv_activation = [tf.nn.relu] * 3\n",
    "\n",
    "n_hidden_in = 5 * 5 * 64\n",
    "n_hidden = 512\n",
    "hidden_activation = tf.nn.relu\n",
    "n_outputs = env.action_space.n\n",
    "\n",
    "he_init = tf.contrib.layers.variance_scaling_initializer()"
   ]
  },
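  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, `n_hidden_in = 5 * 5 * 64` follows from the convolution arithmetic: with `'VALID'` padding, each layer outputs `floor((in - kernel) / stride) + 1` pixels per side:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 80  # the preprocessed input is 80x80\n",
    "for kernel_size, stride in zip([9, 5, 3], [4, 2, 1]):\n",
    "    size = (size - kernel_size) // stride + 1\n",
    "    print(size)  # 18, then 7, then 5\n",
    "\n",
    "print(size * size * 64)  # 1600, i.e. 5 * 5 * 64 = n_hidden_in"
   ]
  },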
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model uses two DQNs: an online DQN, whose parameters are updated at each training step, and a target DQN, which is used to compute the target Q-Values for the online DQN's loss function. The online DQN's parameters are copied to the target DQN at regular intervals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "def q_network(X_state, name):\n",
    "    prev_layer = X_state\n",
    "    with tf.variable_scope(name) as scope:\n",
    "        for n_maps, kernel_size, strides, padding, activation in zip(\n",
    "            conv_n_maps, conv_kernel_sizes, conv_kernel_strides, conv_paddings,\n",
    "            conv_activation):\n",
    "            prev_layer = tf.layers.conv2d(prev_layer, filters=n_maps,\n",
    "                                          kernel_size=kernel_size,\n",
    "                                          strides=strides, padding=padding,\n",
    "                                          activation=activation,\n",
    "                                          kernel_initializer=he_init)\n",
    "        flattened = tf.reshape(prev_layer, [-1, n_hidden_in])\n",
    "        hidden = tf.layers.dense(flattened, n_hidden,\n",
    "                                 activation=hidden_activation,\n",
    "                                 kernel_initializer=he_init)\n",
    "        outputs = tf.layers.dense(hidden, n_outputs, kernel_initializer=he_init)\n",
    "    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\n",
    "                                       scope=scope.name)\n",
    "    trainable_vars_by_name = {var.name[len(scope.name):]: var\n",
    "                              for var in trainable_vars}\n",
    "    return outputs, trainable_vars_by_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Starting the DQN definition.\n",
    "\n",
    "X_state = tf.placeholder(tf.float32, shape=(None, input_height, input_width,\n",
    "                                            input_channels))\n",
    "online_q_values, online_vars = q_network(X_state, 'q_networks/online')\n",
    "target_q_values, target_vars = q_network(X_state, 'q_networks/target')\n",
    "copy_ops = [var.assign(online_vars[name]) for name, var in target_vars.items()]\n",
    "copy_online_to_target = tf.group(*copy_ops)"
   ]
  },
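  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note how `trainable_vars_by_name` strips the scope prefix from each variable name: the online and target networks then share identical keys, which is what lets `copy_ops` pair each target variable with its online counterpart. Here is a tiny pure-Python sketch of that pairing (the variable names below are made up for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def strip_scope(vars_dict, scope):\n",
    "    # Drop the leading scope prefix, keeping the per-layer suffix as the key.\n",
    "    return {name[len(scope):]: value for name, value in vars_dict.items()}\n",
    "\n",
    "online_demo = {'q_networks/online/conv2d/kernel:0': 'w_online'}\n",
    "target_demo = {'q_networks/target/conv2d/kernel:0': 'w_target'}\n",
    "\n",
    "online_by_name = strip_scope(online_demo, 'q_networks/online')\n",
    "target_by_name = strip_scope(target_demo, 'q_networks/target')\n",
    "print(online_by_name.keys() == target_by_name.keys())  # True: the keys line up"
   ]
  },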
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining the training objective.\n",
    "\n",
    "learning_rate = 1e-3\n",
    "momentum = 0.95\n",
    "\n",
    "with tf.variable_scope('training') as scope:\n",
    "    X_action = tf.placeholder(tf.int32, shape=(None,))\n",
    "    y = tf.placeholder(tf.float32, shape=(None, 1))\n",
    "    q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n",
    "                            axis=1, keepdims=True)  # online Q for chosen action\n",
    "    error = tf.abs(y - q_value)\n",
    "    loss = tf.reduce_mean(tf.square(error))\n",
    "\n",
    "    global_step = tf.Variable(0, trainable=False, name='global_step')\n",
    "    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum,\n",
    "                                           use_nesterov=True)\n",
    "    training_op = optimizer.minimize(loss, global_step=global_step)"
   ]
  },
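  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The target values fed into `y` follow the Q-Learning update $y = r + \\gamma \\max_{a'} Q_{target}(s', a')$, with the `continue` flag zeroing out the bootstrap term on terminal transitions. Here is a quick NumPy sketch of that computation, using made-up numbers (not taken from an actual training run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "gamma = 0.95\n",
    "sample_rewards = np.array([[1.0], [0.0]])\n",
    "sample_continues = np.array([[1.0], [0.0]])  # 0.0 marks a terminal transition\n",
    "sample_next_q = np.array([[0.5, 2.0, 1.0],   # pretend target DQN outputs for s'\n",
    "                          [3.0, 0.1, 0.2]])\n",
    "max_next_q = np.max(sample_next_q, axis=1, keepdims=True)\n",
    "y_sample = sample_rewards + sample_continues * gamma * max_next_q\n",
    "print(y_sample)  # the terminal transition keeps only its own reward"
   ]
  },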
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model will sample past experiences from a _Replay Memory_. This should help the model learn which higher-level patterns to pay attention to in order to find the right action, and it reduces the chance that the model's behavior becomes too correlated with its most recent experiences.\n",
    "\n",
    "The replay memory will store its data in the kernel's memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayMemory(object):\n",
    "    def __init__(self, maxlen):\n",
    "        self.maxlen = maxlen\n",
    "        self.buf = np.empty(shape=maxlen, dtype=np.object)\n",
    "        self.index = 0\n",
    "        self.length = 0\n",
    "\n",
    "    def append(self, data):\n",
    "        self.buf[self.index] = data\n",
    "        self.index += 1\n",
    "        self.index %= self.maxlen\n",
    "        self.length = min(self.length + 1, self.maxlen)\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        return self.buf[np.random.randint(self.length, size=batch_size)]"
   ]
  },
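  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, self-contained sanity check of the circular-buffer behavior: the index arithmetic from `append()` is repeated inline with a tiny `maxlen`, showing that once the buffer is full, new entries overwrite the oldest ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "demo_maxlen = 3\n",
    "demo_buf = [None] * demo_maxlen\n",
    "demo_index, demo_length = 0, 0\n",
    "for data in ['a', 'b', 'c', 'd']:\n",
    "    demo_buf[demo_index] = data\n",
    "    demo_index = (demo_index + 1) % demo_maxlen\n",
    "    demo_length = min(demo_length + 1, demo_maxlen)\n",
    "\n",
    "print(demo_length)  # 3: the buffer is full\n",
    "print(demo_buf)     # ['d', 'b', 'c']: 'a' was overwritten by 'd'"
   ]
  },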
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "replay_size = 200000\n",
    "replay_memory = ReplayMemory(replay_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_memories(batch_size):\n",
    "    cols = [[], [], [], [], []]  # state, action, reward, next_state, continue\n",
    "    for memory in replay_memory.sample(batch_size):\n",
    "        for col, value in zip(cols, memory):\n",
    "            col.append(value)\n",
    "    cols = [np.array(col) for col in cols]\n",
    "    return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], \\\n",
    "         cols[4].reshape(-1, 1)"
   ]
  },
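  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The column splitting done by `sample_memories()` can be illustrated on its own with a few dummy `(state, action, reward, next_state, continue)` tuples: the rewards and continues come out as column vectors, ready to broadcast in the Q-Learning target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dummy_memories = [(0, 1, 1.0, 0, 1.0),\n",
    "                  (0, 0, 0.0, 0, 1.0),\n",
    "                  (0, 2, -1.0, 0, 0.0)]\n",
    "dummy_cols = [np.array(col) for col in zip(*dummy_memories)]\n",
    "dummy_rewards = dummy_cols[2].reshape(-1, 1)\n",
    "dummy_continues = dummy_cols[4].reshape(-1, 1)\n",
    "print(dummy_rewards.shape, dummy_continues.shape)  # (3, 1) (3, 1)"
   ]
  },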
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define the model's policy during training. Just like in `MsPacMan.ipynb`, we will use an $\\varepsilon$-greedy policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps_min = 0.1\n",
    "eps_max = 1.0\n",
    "eps_decay_steps = 6000000\n",
    "\n",
    "def epsilon_greedy(q_values, step):\n",
    "    epsilon = max(eps_min,\n",
    "                  eps_max - ((eps_max - eps_min) * (step / eps_decay_steps)))\n",
    "    if np.random.random() < epsilon:\n",
    "        return np.random.randint(n_outputs)\n",
    "    return np.argmax(q_values)"
   ]
  },
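  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick standalone check of the intended schedule: epsilon decays linearly from `eps_max` to `eps_min` over `eps_decay_steps` training steps, then stays at `eps_min` (the `max()` is what stops the decay). The constants are repeated inline so this cell runs on its own:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps_lo, eps_hi, decay = 0.1, 1.0, 6000000  # same values as above\n",
    "\n",
    "def epsilon_at(step):\n",
    "    return max(eps_lo, eps_hi - (eps_hi - eps_lo) * step / decay)\n",
    "\n",
    "for s in (0, 3000000, 6000000, 9000000):\n",
    "    print(s, epsilon_at(s))\n",
    "# epsilon decays linearly to eps_lo, then stays there"
   ]
  },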
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will train the model to play some Pong. The model will choose an action once every 3 frames. The preprocessing functions defined above will use those 3 frames to compute the state the model uses to select its next action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 10000000\n",
    "training_start = 100000\n",
    "training_interval = 4\n",
    "save_steps = 1000\n",
    "copy_steps = 10000\n",
    "discount_rate = 0.95\n",
    "skip_start = 20\n",
    "batch_size = 50\n",
    "iteration = 0\n",
    "done = True  # To reset the environment at the start.\n",
    "\n",
    "loss_val = np.infty\n",
    "game_length = 0\n",
    "total_max_q = 0.0\n",
    "mean_max_q = 0.0\n",
    "\n",
    "checkpoint_path = \"./pong_dqn.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility function to get the environment state for the model.\n",
    "\n",
    "def perform_action(action):\n",
    "    preprocess_observations = []\n",
    "    total_reward = 0.0\n",
    "    for i in range(3):\n",
    "        obs, reward, done, info = env.step(action)\n",
    "        total_reward += reward\n",
    "        if done:\n",
    "            for _ in range(i, 3):\n",
    "                preprocess_observations.append(preprocess_observation(obs))\n",
    "            break\n",
    "        else:\n",
    "            preprocess_observations.append(preprocess_observation(obs))\n",
    "    return combine_observations(preprocess_observations).reshape(80, 80, 1), \\\n",
    "        total_reward, done"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from /content/gdrive/My Drive/models/pong_dqn.ckpt\n",
      "Iteration 1056803\tTraining step 9291202/10000000 (92.9)%\tLoss 0.014324\tMean Max-Q 0.036826   "
     ]
    }
   ],
   "source": [
    "# Main training loop\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    if os.path.isfile(checkpoint_path + '.index'):\n",
    "        saver.restore(sess, checkpoint_path)\n",
    "    else:\n",
    "        init.run()\n",
    "        copy_online_to_target.run()\n",
    "    while True:\n",
    "        step = global_step.eval()\n",
    "        if step >= n_steps:\n",
    "            break\n",
    "        iteration += 1\n",
    "        print('\\rIteration {}\\tTraining step {}/{} ({:.1f})%\\tLoss {:5f}'\n",
    "              '\\tMean Max-Q {:5f}   '.format(\n",
    "                  iteration, step, n_steps, 100 * step / n_steps, loss_val,\n",
    "                  mean_max_q),\n",
    "              end='')\n",
    "        if done:\n",
    "            obs = env.reset()\n",
    "            for _ in range(skip_start):\n",
    "                obs, reward, done, info = env.step(0)\n",
    "            state, reward, done = perform_action(0)\n",
    "\n",
    "        # Evaluate the next action for the agent.\n",
    "        q_values = online_q_values.eval(\n",
    "            feed_dict={X_state: [state]})\n",
    "        action = epsilon_greedy(q_values, step)\n",
    "\n",
    "        # The online DQN plays the game.\n",
    "        next_state, reward, done = perform_action(action)\n",
    "\n",
    "        # Save the result in the ReplayMemory.\n",
    "        replay_memory.append((state, action, reward, next_state, 1.0 - done))\n",
    "        state = next_state\n",
    "\n",
    "        # Compute statistics which help us monitor how training is going.\n",
    "        total_max_q += q_values.max()\n",
    "        game_length += 1\n",
    "        if done:\n",
    "            mean_max_q = total_max_q / game_length\n",
    "            total_max_q = 0.0\n",
    "            game_length = 0\n",
    "\n",
    "        # Only train after the warmup rounds and only every few rounds.\n",
    "        if iteration < training_start or iteration % training_interval != 0:\n",
    "            continue\n",
    "\n",
    "        # Sample memories from the replay memory.\n",
    "        X_state_val, X_action_val, rewards, X_next_state_val, continues = \\\n",
    "            sample_memories(batch_size)\n",
    "        next_q_values = target_q_values.eval(\n",
    "            feed_dict={X_state: X_next_state_val})\n",
    "        max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)\n",
    "        y_val = rewards + continues * discount_rate * max_next_q_values\n",
    "\n",
    "        # Train the online DQN.\n",
    "        _, loss_val = sess.run([training_op, loss], feed_dict={\n",
    "            X_state: X_state_val,\n",
    "            X_action: X_action_val,\n",
    "            y: y_val,\n",
    "        })\n",
    "\n",
    "        # Regularly copy the online DQN to the target DQN.\n",
    "        if step % copy_steps == 0:\n",
    "            copy_online_to_target.run()\n",
    "\n",
    "        # Regularly save the model.\n",
    "        if step and step % save_steps == 0:\n",
    "            saver.save(sess, checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from /content/gdrive/My Drive/models/pong_dqn.ckpt\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAAEACAYAAAAp2kPsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAA+tJREFUeJzt3cFNG0EAQFFvRAOJa0gHXOCOlIgW\n6CEShaCkB1pAOXCHCx2kBqcF5xAlkomDAX/ba/zebeX1ei5fs2PPwjCfzyfA+t7tegDwVogJImKC\niJggIiaIiAkiYoKImCAiJoiICSJHux7AMsMwPLnH6eun99saCvz15fvP4anXRxnTJmI5Oz150fm3\nd/drvX/ZNd6qh8vzhePjq5sXnf8cq645Bm7zICImiIgJIqNcM23CqvXLumuq11yD35ath16zrto1\nMxNExAQRMUHkYNZM1jPjtY/ro2XMTBARE0TEBJGDWTM9dij75vaB35mABWKCiJggIiaIHOwXEKt+\nxK03xvJ/+/hlwzJmJoiICSJigsgwxn929u3zh/ENioO36q8TmZkgIiaIjPI2bzabjW9QHLzpdOo2\nD7ZBTBARE0TEBBExQURMEBETRMQEETFBZJQ7IGx0ZYxsdIUtERNExAQRMUFETBARE0TEBBExQURM\nEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBAR\nE0TEBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQERNE\nxAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBETRMQE\nETFB5GjXA4BNeLg8Xzg+vrrZ+GeamSAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAi\nJoiICSJigoiHA3mTtvEw4GNmJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJi\ngoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKI\nmCAiJoiICSJigoiYICImiIgJImKCiJggIiaIiAkiYoKImCAiJoiICSJigoiYICImiIgJImKCiJgg\nIiaIiAkiYoKImCAiJoiICSJigoiYIHK06wFsytnpycLx7d39jkbCvrq4/jGZTCaT64uPzzrfzAQR\nMUFETBB5s2smWNdz10p/mJkgIiaIiAkiYoKImCAiJogM8/l812P4x2w2G9+gOHjT6XR46nUzE0TE\nBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURMEBnlw4Gwj8xMEBETRMQEETFB\nREwQERNExAQRMUFETBARE0TEBBExQURMEBETRMQEETFBREwQERNExAQRMUFETBARE0TEBBExQURM\nEBETRH4BDcZWTLXJ8XUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "preprocess_observations = []\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, checkpoint_path)\n",
    "\n",
    "    def dqn_policy(obs, i):\n",
    "        if len(preprocess_observations) < 3:\n",
    "            preprocess_observations.append(preprocess_observation(obs))\n",
    "            if len(preprocess_observations) == 3:\n",
    "                state = combine_observations(preprocess_observations)\n",
    "                q_values = online_q_values.eval(\n",
    "                    feed_dict={X_state: [state.reshape(80, 80, 1)]})\n",
    "                dqn_policy.cur_action = np.argmax(q_values)\n",
    "            return dqn_policy.cur_action\n",
    "        preprocess_observations[i % 3] = preprocess_observation(obs)\n",
    "        if i % 3 == 2:\n",
    "            state = combine_observations(preprocess_observations)\n",
    "            q_values = online_q_values.eval(\n",
    "                feed_dict={X_state: [state.reshape(80, 80, 1)]})\n",
    "            dqn_policy.cur_action = np.argmax(q_values)\n",
    "        return dqn_policy.cur_action\n",
    "    dqn_policy.cur_action = 0\n",
    "\n",
    "    html = run_episode(dqn_policy, n_max_steps=10000)\n",
    "html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Special thanks to [Dylan Cutler](https://github.com/DCtheTall) for contributing this solution!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
